diff --git a/.gitignore b/.gitignore index 0247f55..ee77abb 100644 --- a/.gitignore +++ b/.gitignore @@ -11,9 +11,11 @@ target/packer/out .stignore pyrightconfig.json TODO -run.sh +run*.sh .env __pycache__/ *.py[co] client/warcupload/testdata upload-space/ +*.tar +PIPELINES diff --git a/client/worker/Dockerfile b/client/worker/Dockerfile index 98d169f..a07586d 100644 --- a/client/worker/Dockerfile +++ b/client/worker/Dockerfile @@ -12,7 +12,7 @@ RUN apt-get update \ && apt-get clean # 06fc825d3fbed9801110b2d3f562c44d72940862 is known to work. -RUN pip3 install --break-system-packages rethinkdb git+https://github.com/internetarchive/brozzler@06fc825d3fbed9801110b2d3f562c44d72940862 websockets doublethink yt-dlp aiofiles +RUN pip3 install --break-system-packages rethinkdb git+https://github.com/internetarchive/brozzler@06fc825d3fbed9801110b2d3f562c44d72940862 websockets doublethink yt-dlp aiofiles uuid_utils RUN pip3 install --break-system-packages --upgrade rethinkdb RUN mkdir /app diff --git a/client/worker/app.py b/client/worker/app.py index 1b14fd8..b254834 100755 --- a/client/worker/app.py +++ b/client/worker/app.py @@ -105,20 +105,24 @@ async def warcprox_cleanup(): except Exception: pass -async def run_job(ws: Websocket, full_job: dict, url: str, warc_prefix: str, ua: str, custom_js: typing.Optional[str], info_url: str): - tries = full_job['_current_attempt'] - id = full_job['id'] +async def run_job(ws: Websocket, full_job: dict, info_url: str): + job_id = full_job['job_id'] + page_id = full_job['page_id'] + attempt_id = full_job['attempt_id'] + assert "_" not in job_id + warc_prefix = "mnbot-brozzler-" + job_id.replace("-", "_") #dedup_bucket = f"dedup-{id}-{tries}" dedup_bucket = "" - stats_bucket = f"stats-{id}-{tries}" + stats_bucket = f"stats-{job_id}" job = Job( + attempt_id = attempt_id, full_job = full_job, - url = url, + url = full_job['payload'], warc_prefix = warc_prefix, dedup_bucket = dedup_bucket, stats_bucket = stats_bucket, - ua = ua, - 
custom_js = custom_js, + ua = full_job['settings']['ua'], + custom_js = full_job['settings']['custom_js'], cookie_jar = None, mnbot_info_url = info_url ) @@ -129,7 +133,10 @@ async def run_job(ws: Websocket, full_job: dict, url: str, warc_prefix: str, ua: PYTHON, os.path.join(os.path.dirname(sys.argv[0]), "browse.py"), str(pwrite), - id, # useful for ps + # These arguments are passed only because they are useful for ps + job_id, + page_id, + attempt_id, stdin = subprocess.PIPE, stdout = subprocess.PIPE, stderr = subprocess.STDOUT, @@ -174,16 +181,8 @@ async def run_job(ws: Websocket, full_job: dict, url: str, warc_prefix: str, ua: res = json.loads(res) type = res['type'] payload = res['payload'] - if type in ("status_code", "outlinks", "final_url", "requisites", "custom_js"): - await ws.store_result(id, type, tries, payload) - elif type == "screenshot": - # Don't tell the tracker to decode the thumbnail - # if there isn't a thumbnail - decode_fields = [k for k in ("full", "thumb") if payload[k]] - await ws.store_result(id, type, tries, payload, decode_fields) - elif type == "cjs_screenshot": - decode_fields = ["full"] - await ws.store_result(id, type, tries, payload, decode_fields) + if type in ("status_code", "outlinks", "final_url", "requisites", "custom_js", "screenshot", "cjs_screenshot"): + await ws.store_result(attempt_id, type, payload) elif type == "error": # Cancel tasks, since both stdout and pread are about to get closed. # Failing to do this results in a "Task exception was never retrieved" @@ -233,6 +232,7 @@ async def ping_occasionally(): pass await asyncio.sleep(15) + # Don't change this without changing the hardcoded slot=0 below and the hardcoded num_slots=1 in tracker.py. 
MAX_WORKERS = 1 workers: dict[asyncio.Task, tuple[TaskType, dict | None]] = dict() @@ -266,41 +266,36 @@ def handle_sigint(): logger.debug("not spinning up new item as we are pending a stop") continue logger.debug("spinning up worker") - resp = await ws.claim_item() + resp = await ws.claim_item(0) if resp: - item, info_url = resp - id = item['id'] - logger.info(f"Starting task {id}") - url = item['item'] - assert "_" not in id - prefix = "mnbot-brozzler-" + id.replace("-", "_") + claim, info_url = resp + print(claim, info_url) + attempt_id = claim['attempt_id'] + job_id = claim['job_id'] + page_id = claim['page_id'] + logger.info(f"Starting claim {attempt_id} (for {job_id} : {page_id}") task = asyncio.create_task(run_job( ws, - item, - url, - prefix, - item['metadata']['ua'], - item['metadata']['custom_js'], - info_url + claim, + info_url, )) - task.set_name(id) - workers[task] = (TaskType.ITEM, item) + task.set_name(attempt_id) + workers[task] = (TaskType.ITEM, claim) else: to_sleep = random.randint(10, 30) - logger.info(f"No items found, blocking this worker for {to_sleep} seconds.") + logger.info(f"No tasks found, blocking this worker for {to_sleep} seconds.") task = asyncio.create_task(asyncio.sleep(to_sleep)) workers[task] = (TaskType.SLEEP, None) done: set[asyncio.Task] = (await asyncio.wait(workers, return_when = asyncio.FIRST_COMPLETED))[0] for finished_task in done: logger.debug(f"checking finished task {finished_task}") - task_type, item = workers[finished_task] + task_type, task_claim = workers[finished_task] del workers[finished_task] if task_type != TaskType.ITEM: logger.debug("nevermind, not an item") continue - # item can't be None at this point - id = item['id'] - tries = item['_current_attempt'] + # claim can't be None at this point + attempt_id = task_claim['attempt_id'] try: _dedup_bucket, _stats_bucket = finished_task.result() except Exception as e: @@ -309,14 +304,14 @@ def handle_sigint(): fatal = e.fatal else: fatal = False - 
logger.exception(f"failed task {id}:") + logger.exception(f"failed task {attempt_id}:") fmt = io.StringIO() finished_task.print_stack(file = fmt) message = f"Caught exception!\n{fmt.getvalue()}" - await ws.fail_item(id, message, tries, fatal) + await ws.fail_item(attempt_id, message, fatal) else: logger.info(f"task {id} was successful!") - await ws.finish_item(id) + await ws.finish_item(attempt_id) logger.debug("creating cleanup task") task = asyncio.create_task(warcprox_cleanup()) workers[task] = (TaskType.CLEANUP, None) diff --git a/client/worker/browse.py b/client/worker/browse.py index 18875df..a59f8e1 100644 --- a/client/worker/browse.py +++ b/client/worker/browse.py @@ -342,12 +342,12 @@ def _brozzle(self) -> Result: ua = self._create_user_agent(version) # Write job info to the WARC - logger.debug("writing item info") + logger.debug("writing job info") self._write_warcprox_record( - "metadata:mnbot-job-metadata", + f"metadata:mnbot-metadata/{self.job.attempt_id}", "application/json", json.dumps({ - "job": self.job.full_job, + "claim": self.job.full_job, "version": VERSION, "browser": { "executable": self.chrome_exe, @@ -403,7 +403,7 @@ def _brozzle(self) -> Result: with self.websock_thread_lock: r = Result( - id = self.job.full_job['id'], + id = self.job.attempt_id, final_url = final_url, outlinks = list(outlinks), custom_js = custom_js_result, @@ -413,10 +413,11 @@ def _brozzle(self) -> Result: ) logger.debug("writing job result data") self._write_warcprox_record( - "metadata:mnbot-job-result", + f"metadata:mnbot-result/{self.job.attempt_id}", "application/json", json.dumps({ - "result": r.dict() + "result": r.dict(), + "attempt_id": self.job.attempt_id, }).encode(), self.job.warc_prefix ) @@ -445,17 +446,9 @@ def write_message(type, payload): write_message("requisites", [dataclasses.asdict(v) for v in result.requisites.values()]) write_message("status_code", result.status_code) - screenshot = browser.screenshot - thumbnail = browser.thumbnail - if 
screenshot: - screenshot = base64.b85encode(screenshot).decode() - if thumbnail: - thumbnail = base64.b85encode(thumbnail).decode() - if screenshot or thumbnail: - write_message("screenshot", { - "full": screenshot, - "thumb": thumbnail, - }) + if browser.screenshot: + screenshot = base64.b85encode(browser.screenshot).decode() + write_message("screenshot", screenshot) if result.custom_js_screenshot: write_message("cjs_screenshot", {"full": result.custom_js_screenshot}) diff --git a/client/worker/shared.py b/client/worker/shared.py index a4c81f0..7930ce7 100644 --- a/client/worker/shared.py +++ b/client/worker/shared.py @@ -7,7 +7,7 @@ # Update this whenever you make a change, cosmetic or not. # During development you can ignore it, but when you actually # push it to prod, it *must* be updated. -VERSION = "20260412.01" +VERSION = "20260626.01" DEBUG = os.environ.get("DEBUG") == "1" if DEBUG: @@ -19,6 +19,7 @@ @dataclasses.dataclass class Job: full_job: dict + attempt_id: str url: str warc_prefix: str dedup_bucket: str diff --git a/client/worker/tracker.py b/client/worker/tracker.py index 4f105a9..a4bbc5f 100644 --- a/client/worker/tracker.py +++ b/client/worker/tracker.py @@ -1,15 +1,12 @@ -### HEY YOU! Yeah, you! -### If you are making any change to the client, please -### update the version in meta.py. -### No two versions in prod should have the same version number. - ############################# NOTE! ############################### ### If you are making any change to the client, please ### ### update the version in shared.py. ### ### No two versions in prod should have the same version number.### ############################ THANKS! 
############################## -import asyncio, logging, json +import asyncio +import json +import logging import typing from shared import VERSION @@ -19,6 +16,14 @@ import websockets from websockets.asyncio.client import connect +import uuid_utils.compat as uuid + +class CustomEncoder(json.JSONEncoder): + def default(self, obj): + if isinstance(obj, uuid.UUID): + return str(obj) + return super().default(obj) +encoder = CustomEncoder() class Websocket: def __init__(self, url: str, advisory_handler): @@ -32,13 +37,14 @@ async def _send_once(self, type, payload=None) -> tuple[int, dict]: async with self.lock: if not self.conn: logger.debug("creating connection") - self.conn = await connect(self.url) - await self.conn.send(json.dumps({"v": VERSION})) + conn = await connect(self.url) + await conn.send(encoder.encode({"v": VERSION, "p": 2, "num_slots": 1})) + self.conn = conn self.seq += 1 seq = self.seq message = {"type": type, "request": payload, "seq": seq} - logger.debug(f"sending message (keys: {list(message.keys())})") - await self.conn.send(json.dumps(message)) + logger.debug(f"sending {type} message (seq = {seq})") + await self.conn.send(encoder.encode(message)) while True: resp = json.loads(await self.conn.recv()) if resp['seq'] is None: # Advisory @@ -79,41 +85,39 @@ async def _send(self, type, payload=None): await asyncio.sleep(sleep) tries += 1 - async def claim_item(self) -> typing.Optional[tuple[dict, str]]: - status, resp = await self._send("Item:claim", {"pipeline_type": "brozzler"}) + async def claim_item(self, slot: int) -> typing.Optional[tuple[dict, str]]: + status, resp = await self._send("Item:claim", {"slot": slot}) if status != 200: raise RuntimeError(f"Bad response from server: {status} {resp}") if resp['item']: return resp['item'], resp['info_url'] return None - async def fail_item(self, id: str, reason: str, tries: int, fatal: bool): + async def fail_item(self, attempt_id: str, reason: str, fatal: bool): status, resp = await self._send( 
"Item:fail", { - "id": id, + "attempt_id": attempt_id, "message": reason, - "attempt": tries, - "fatal": fatal + "fatal": fatal, } ) if status != 204: raise RuntimeError(f"Bad response from server: {status} {resp}") - async def finish_item(self, id: str): - status, resp = await self._send("Item:finish", {"id": id}) + async def finish_item(self, attempt_id: str): + status, resp = await self._send("Item:finish", {"attempt_id": attempt_id}) if status != 204: raise RuntimeError(f"Bad response from server: {status} {resp}") - async def store_result(self, id: str, result_type: str, tries: int, result, decode_fields = None): + async def store_result(self, attempt_id: str, result_type: str, result): + result_id = uuid.uuid7() pl = { - "result_type": result_type, - "attempt": tries, - "result": result, - "id": id + "result_id": result_id, + "attempt_id": attempt_id, + "type": result_type, + "payload": result, } - if decode_fields: - pl['decode_fields'] = decode_fields status, resp = await self._send("Item:store", pl) if status != 201: raise RuntimeError(f"Bad response from server: {status} {resp}") diff --git a/test_db.py b/test_db.py new file mode 100644 index 0000000..2a70923 --- /dev/null +++ b/test_db.py @@ -0,0 +1,827 @@ +import sys +import asyncio +import os +import os.path +import secrets +import string +import dataclasses + +import sqlalchemy, sqlalchemy.ext.asyncio, sqlalchemy.exc, sqlalchemy.dialects.postgresql.asyncpg +import asyncpg + +import pytest +import pytest_asyncio + +from tracker.common import db, model +from tracker.scripts.__main__ import add_pipeline, create + +@dataclasses.dataclass +class CreateArgs: + uri: str + tries: int + +@dataclasses.dataclass +class AddPipelineArgs: + uri: str + id: str + matchonly: bool + +CREATED = asyncio.Event() + +phase_report_key = pytest.StashKey[dict[str, pytest.CollectReport]]() +# https://docs.pytest.org/en/latest/example/simple.html#making-test-result-information-available-in-fixtures +# Allows fixtures to tell 
whether the test passed - currently unused, but may be useful +@pytest.hookimpl(wrapper=True, tryfirst=True) +def pytest_runtest_makereport(item, call): + rep = yield + + # store test results for each phase of a call, which can + # be "setup", "call", "teardown" + item.stash.setdefault(phase_report_key, {})[rep.when] = rep + + return rep + +@pytest_asyncio.fixture() +async def engine(request): + uri = os.environ['MNBOT_DATABASE_URI'] + root_engine = sqlalchemy.ext.asyncio.create_async_engine(uri, isolation_level = "AUTOCOMMIT") + async with root_engine.connect() as conn: + await conn.execute(sqlalchemy.text("DROP SCHEMA IF EXISTS public CASCADE")) + await conn.execute(sqlalchemy.text("CREATE SCHEMA public AUTHORIZATION pg_database_owner")) + await conn.execute(sqlalchemy.text("GRANT USAGE ON SCHEMA public TO PUBLIC")) + await conn.execute(sqlalchemy.text("GRANT ALL ON SCHEMA public TO pg_database_owner")) + await conn.commit() + await create(CreateArgs(uri, 2)) + + engine = await db.create_engine(uri) + yield engine + await engine.dispose() + + await root_engine.dispose() + +def _test(f): + return pytest.mark.asyncio(f) + +@_test +async def test_authenticate(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests that pipeline authentication works. + """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("foo", False, "password") + await q.create_pipeline("bar", True, "password1") + foo = await q.pipeline("foo") + bar = await q.pipeline("bar") + + with pytest.raises(db.AuthenticationFailure): + await foo.authenticate("password1") + with pytest.raises(db.AuthenticationFailure): + await bar.authenticate("password") + await foo.authenticate("password") + await bar.authenticate("password1") + +@_test +async def test_get_job_counts(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests that the active job count works. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + async def make_job(): + await q.create_jobs([db.JobCreation(db.generate_id(), model.JobType.BROZZLER, "foo", {}, "")]) + + assert (await q.get_job_counts()) == {} + await make_job() + assert (await q.get_job_counts()) == {model.JobType.BROZZLER: 1} + await make_job() + await make_job() + assert (await q.get_job_counts()) == {model.JobType.BROZZLER: 3} + +async def make_job(q, status = model.JobStatus.ACTIVE, concurrency = 0, nice = 0, tag = None, type = model.JobType.BROZZLER, depth = None): + id = db.generate_id() + await q.create_jobs([db.JobCreation(id, type, "foo", {}, "", status, concurrency, nice, tag, None, depth)]) + return id + +async def make_pages(q, job_id, *payloads, parent_page = None): + pagecs = [] + ids = [] + for page in payloads: + id = db.generate_id() + ids.append(id) + pagecs.append(db.PageCreation(id, page, parent_page)) + pages = await q.create_pages(job_id, pagecs) + return [pages[page_id] for page_id in ids] + +async def check_claim(q, expected_id, matchonly = False): + pipe = await q.pipeline("pipe") + res = await pipe._find_claimable_job(model.JobType.BROZZLER, matchonly, None) + assert res == expected_id + +async def check_all_claims(q, *expected_ids, type = model.JobType.BROZZLER): + res = await q.get_all_claimable_jobs(type) + assert tuple(res) == expected_ids + +@_test +async def test_job_order(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests that jobs are dequeued in the correct order. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + await check_claim(q, None) + # Concurrency 0 or status != active = no claim + await make_job(q) + await make_job(q, status = model.JobStatus.DRAINING, concurrency = 1) + await check_claim(q, None) + await check_all_claims(q) + + # The pipeline has no tags associated with it, so tags should not work + idtag = await make_job(q, concurrency = 1, tag = "baz", nice = 999) + await check_claim(q, None) + # However, it should show up in check_all_claims, which has no restrictions on tags + await check_all_claims(q, idtag) + + id0 = await make_job(q, concurrency = 1, nice = 1) + await check_claim(q, id0) + id1 = await make_job(q, concurrency = 1) + await check_claim(q, id1) + id2 = await make_job(q, concurrency = 7) + await check_claim(q, id1) + await check_all_claims(q, id1, id2, id0, idtag) + id3 = await make_job(q, concurrency = 1, nice = -1) + await check_all_claims(q, id3, id1, id2, id0, idtag) + +@_test +async def test_tags(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests that the tagging system works properly. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", True, "password") + pipe = await q.pipeline("pipe") + await pipe.create_tags("foo", "bar") + # Ensure duplicate tags are silently ignored and don't throw an error + await pipe.create_tags("foo", "baz") + + # This tag is not assigned to the pipeline, so it shouldn't work + id1 = await make_job(q, concurrency = 1, tag = "quux") + await check_claim(q, None) + await check_claim(q, None, matchonly = True) + await check_all_claims(q, id1) + + id2 = await make_job(q, concurrency = 1, tag = "foo") + await check_claim(q, id2) + await check_claim(q, id2, matchonly = True) + await check_all_claims(q, id1, id2) + + # This job has no tag, so matchonly should ignore it + id3 = await make_job(q, concurrency = 1, nice = -1) + await check_claim(q, id3) + await check_claim(q, id2, matchonly = True) + await check_all_claims(q, id3, id1, id2) + + # Ensure that NONE jobs aren't claimed, even when they have a matching tag + id4 = await make_job(q, type = model.JobType.NONE, tag = "foo", nice = -10, concurrency = 1) + await check_claim(q, id3) + await check_claim(q, id2, matchonly = True) + await check_all_claims(q, id3, id1, id2) + await check_all_claims(q, id4, type = model.JobType.NONE) + + # Remove tag and see if id2 disappears from matchonly + await pipe.remove_tags("foo") + await check_claim(q, id3) + await check_claim(q, None, matchonly = True) + await check_all_claims(q, id3, id1, id2) + +@_test +async def test_claiming(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests set_claim. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + job = await make_job(q, concurrency = 1) + job2 = await make_job(q, concurrency = 1) + async def check_claim_count(expected1, expected2): + q = sqlalchemy.select(model.jobs.c.active_claims).where(model.jobs.c.job_id == job) + val = (await conn.execute(q)).first()[0] + q = sqlalchemy.select(model.jobs.c.active_claims).where(model.jobs.c.job_id == job2) + val2 = (await conn.execute(q)).first()[0] + assert (val, val2) == (expected1, expected2) + + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0, 1) + + await check_claim_count(0, 0) + await pipe._set_claim(0, job) + await check_claim_count(1, 0) + await pipe._set_claim(0, job2) + await check_claim_count(0, 1) + await pipe._set_claim(1, job) + await check_claim_count(1, 1) + await pipe._set_claim(0, job) + await check_claim_count(2, 0) + await pipe._set_claim(0, None) + await check_claim_count(1, 0) + await pipe._set_claim(0, None) + await check_claim_count(1, 0) + await pipe._set_claim(1, None) + await check_claim_count(0, 0) + + with pytest.raises(db.NoSuchPipelineError): + await pipe._set_claim(2, job) + +@_test +async def test_pipeline(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests basic pipeline lifecycle. + """ + async with engine.connect() as conn: + q = db.Connection(conn) + # Create pipelines - one matchonly, one not + await q.create_pipeline("pipe", False, "password") + await q.create_pipeline("pipe_mo", True, "password") + # Register two slots for each + pipe = await q.pipeline("pipe", 0, 1) + pipe_mo = await q.pipeline("pipe_mo", 0, 1) + # Add a tag to both + await pipe.create_tags("tag") + await pipe_mo.create_tags("tag") + + # ... Ok, time for some testing! 
+ job1 = await make_job(q, concurrency = 3) + pageids = await make_pages(q, job1, "one", "two", "three", "four") + pages = [] + pages.extend(( + # one + await pipe.find_claim_page("", 0, model.JobType.BROZZLER), + # two + await pipe.find_claim_page("", 1, model.JobType.BROZZLER), + # None + await pipe_mo.find_claim_page("", 1, model.JobType.BROZZLER), + )) + await conn.execute(sqlalchemy.update(model.jobs).where(model.jobs.c.job_id == job1).values(tag = "tag")) + pages.extend(( + # three + await pipe_mo.find_claim_page("", 1, model.JobType.BROZZLER), + # None (reached concurrency limit) + await pipe_mo.find_claim_page("", 0, model.JobType.BROZZLER), + )) + print(pages) + found_pages = [page.page_id if page else None for page in pages] + found_payloads = [page.payload if page else None for page in pages] + expected_pages = pageids[0:2] + [None, pageids[2], None] + expected_payloads = ["one", "two", None, "three", None] + assert found_pages == expected_pages + assert found_payloads == expected_payloads + +@_test +async def test_retries(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests the retrying/finishing system. + """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0, 1) + job_id = await make_job(q, concurrency = 2) + await make_pages(q, job_id, "one", "two", "three", "four") + + # Claim two pages, finish one of them + claim1 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim1 and claim1.payload == "one" + claim2 = await pipe.find_claim_page("", 1, model.JobType.BROZZLER) + assert claim2 and claim2.payload == "two" + await pipe.finish_attempt(claim1.attempt_id) + # Claim a third page, fail it non-fatally. 
It should be returned to the queue + claim3 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim3 and claim3.payload == "three" + # When it is returned to the queue it should have one try remaining + assert (await pipe.fail_attempt(claim3.attempt_id, "error", False)) == 1 + # Claim a fourth page, failing it fatally + claim4 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim4 and claim4.payload == "four" + # Should thus have no tries remaining + assert (await pipe.fail_attempt(claim4.attempt_id, "error", True)) == 0 + # Ensure claim 3 was recycled back into the queue, and that max tries is taken into account. + claim3_2 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim3_2 and claim3_2.payload == "three" + assert (await pipe.fail_attempt(claim3_2.attempt_id, "error", False)) == 0 + with pytest.raises(db.JobExhausted): + await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + # Ordinarily we would now recalculate the job status, but not in this test + # Ensure that adding retries manually works as intended + await q.retry_page(claim3_2.page_id, 1) + claim3_4 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim3_4 and claim3_4.payload == "three" + assert (await pipe.fail_attempt(claim3_4.attempt_id, "error", False)) == 0 + with pytest.raises(db.JobExhausted): + await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + +async def check_job_status(queue: db.Connection, job_id, expected_status): + q = sqlalchemy.select(model.jobs.c.status).where(model.jobs.c.job_id == job_id) + res = (await queue.conn.execute(q)).first() + assert res and res[0] == expected_status + +@_test +async def test_finishing(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests the update_job_status function. 
+ """ + async with engine.connect() as conn: + queue = db.Connection(conn) + await queue.create_pipeline("pipe", False, "password") + pipe = await queue.pipeline("pipe", 0, 1) + # Create two jobs with some pages + job1 = await make_job(queue, concurrency = 2) + await check_job_status(queue, job1, model.JobStatus.ACTIVE) + job2 = await make_job(queue, concurrency = 1) + await check_job_status(queue, job2, model.JobStatus.ACTIVE) + with pytest.raises(AssertionError): + await check_job_status(queue, job2, model.JobStatus.DRAINING) + await make_pages(queue, job1, "1one", "1two", "1three") + await make_pages(queue, job2, "2one") + + # Start claiming from job 1 + claim1_1 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim1_1 and claim1_1.payload == "1one" + await pipe.finish_attempt(claim1_1.attempt_id) + assert (await pipe.parent.update_job_status(job1)) == model.JobStatus.ACTIVE + await check_job_status(queue, job1, model.JobStatus.ACTIVE) + # Fail page non-fatally + claim1_2 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim1_2 and claim1_2.payload == "1two" + await pipe.fail_attempt(claim1_2.attempt_id, "", False) + assert (await pipe.parent.update_job_status(job1)) == model.JobStatus.ACTIVE + await check_job_status(queue, job1, model.JobStatus.ACTIVE) + # Claim third page but don't fail it yet + claim1_3 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + # It should still be active... 
+ assert (await pipe.parent.update_job_status(job1)) == model.JobStatus.ACTIVE + assert claim1_3 and claim1_3.payload == "1three" + # but if we reclaim the failed page, it should be DRAINING + claim1_2_2 = await pipe.find_claim_page("", 1, model.JobType.BROZZLER) + assert claim1_2_2 and claim1_2_2.payload == "1two" + assert (await pipe.parent.update_job_status(job1)) == model.JobStatus.DRAINING + # Same goes for if we finish one of them (but not both) + await pipe.fail_attempt(claim1_2_2.attempt_id, "", False) + assert (await pipe.parent.update_job_status(job1)) == model.JobStatus.DRAINING + await check_job_status(queue, job1, model.JobStatus.DRAINING) + # And if we finish the other, we're done :-) + await pipe.finish_attempt(claim1_3.attempt_id) + assert (await pipe.parent.update_job_status(job1)) == model.JobStatus.DONE + await check_job_status(queue, job1, model.JobStatus.DONE) + + await check_job_status(queue, job2, model.JobStatus.ACTIVE) + claim2_1 = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim2_1 and claim2_1.payload == "2one" + await pipe.fail_attempt(claim2_1.attempt_id, "", True) + assert (await pipe.parent.update_job_status(job2)) == model.JobStatus.DONE + await check_job_status(queue, job2, model.JobStatus.DONE) + + # Finally, adding more pages should make it ACTIVE again. 
+ await make_pages(queue, job2, "hi") + assert (await pipe.parent.update_job_status(job2)) == model.JobStatus.ACTIVE + +@_test +async def test_update_job_status_with_abort(engine: sqlalchemy.ext.asyncio.AsyncEngine): + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0) + job1 = await make_job(q, concurrency = 1) + await make_pages(q, job1, "a") + + await q.abort_job(job1) + assert (await pipe.find_claim_page("", 0, model.JobType.BROZZLER)) is None + assert (await q.update_job_status(job1) == model.JobStatus.ABORTED) + assert (await q.update_job_status(job1, True) == model.JobStatus.ACTIVE) + # Claim page, ensure it is set to DRAINING + res = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert res and res.payload == "a" + await q.abort_job(job1) + assert (await q.update_job_status(job1) == model.JobStatus.ABORTED) + assert (await q.update_job_status(job1, True) == model.JobStatus.DRAINING) + await q.abort_job(job1) + await pipe.finish_attempt(res.attempt_id) + assert (await q.update_job_status(job1) == model.JobStatus.ABORTED) + assert (await q.update_job_status(job1, True) == model.JobStatus.DONE) + +@_test +async def test_update_job_status_with_depth(engine: sqlalchemy.ext.asyncio.AsyncEngine): + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0) + + job1 = await make_job(q, concurrency = 1, depth = 0) + (a,) = await make_pages(q, job1, "a") + (b,) = await make_pages(q, job1, "b", parent_page = a) + assert (await q.update_job_status(job1)) == model.JobStatus.ACTIVE + claim = (await pipe.find_claim_page("", 0, model.JobType.BROZZLER)) + assert claim and claim.page_id == a + assert (await q.update_job_status(job1)) == model.JobStatus.DRAINING + await pipe.finish_attempt(claim.attempt_id) + # At this point, there is still an item, but it is out of
scope and so the job is done. + assert (await q.update_job_status(job1)) == model.JobStatus.DONE + +@_test +async def test_update_job_status_with_skip(engine: sqlalchemy.ext.asyncio.AsyncEngine): + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0) + + job1 = await make_job(q, concurrency = 1) + await q.create_job_rule(job1, None, db.JobRule("", {"skip": True})) + await make_pages(q, job1, "foo") + # update_job_status will not be aware of the SKIP rule + assert (await q.update_job_status(job1)) == model.JobStatus.ACTIVE + # find_claim_page will be, though, and the claim will fail + try: + await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + except db.JobExhausted as e: + assert e.job_id == job1 + else: + raise AssertionError("Exception not raised") + # Now that the page has been marked SKIPPED, update_job_status should function correctly + assert (await q.update_job_status(job1)) == model.JobStatus.DONE + assert (await pipe.find_claim_page("", 0, model.JobType.BROZZLER)) is None + +async def _get_job_depth(q, job_id, pipe): + res = await q.conn.execute(sqlalchemy.select(pipe.parent._job_depth(job_id))) + return res.first() + +async def _assert_eligible_jobs(pipe: db.Pipeline, job_id: model.UUID, expected_eligible: list, expected_ineligible: list): + """ + Asserts that each listed page of the job is considered eligible or ineligible, as expected. (The order of dequeuing is not checked.)
+ """ + discovered_e = [] + discovered_i = [] + for page in expected_eligible: + res = await pipe._find_claimable_page(job_id, _page_id = page) + discovered_e.append(res.page_id if res else f"ineligible[{page}]") + for page in expected_ineligible: + res = await pipe._find_claimable_page(job_id, _page_id = page) + discovered_i.append(page if not res else f"eligible[{page}]") + assert discovered_e == expected_eligible + assert discovered_i == expected_ineligible + +async def _get_page_depths(pipe, *page_ids): + depths = [] + for page_id in page_ids: + res = await pipe.parent.conn.execute(sqlalchemy.select(pipe.parent._page_depth(page_id))) + depths.append(res.first()[0]) + return depths + +@_test +async def test_depth_tracking(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests whether the relations system works as expected. + """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0) + + job1 = await make_job(q, concurrency = 1, depth = None) + root1, root2 = await make_pages(q, job1, "foo", "bar", parent_page = None) + page11 = (await make_pages(q, job1, "baz", parent_page = root1))[0] + page21, page22 = await make_pages(q, job1, "a", "b", parent_page = root2) + + res = await _get_job_depth(q, job1, pipe) + assert res and res[0] is None + assert (await _get_page_depths(pipe, root1, root2, page11, page21, page22)) == [0, 0, 1, 1, 1] + await _assert_eligible_jobs(pipe, job1, [root1, root2, page11, page21, page22], []) + await conn.execute(sqlalchemy.update(model.jobs).where(model.jobs.c.job_id == job1).values(depth = -1)) + res = await _get_job_depth(q, job1, pipe) + assert res and res[0] == -1 + await _assert_eligible_jobs(pipe, job1, [], [root1, root2, page11, page21, page22]) + with pytest.raises(AssertionError): + # Who tests the tests? 
+ await _assert_eligible_jobs(pipe, job1, [root1], [root2, page11, page21, page22]) + await conn.execute(sqlalchemy.update(model.jobs).where(model.jobs.c.job_id == job1).values(depth = 1)) + await _assert_eligible_jobs(pipe, job1, [root1, root2, page11, page21, page22], []) + # Add new page of depth 2, which should be ineligible + (page211,) = await make_pages(q, job1, "c", parent_page = page21) + await _assert_eligible_jobs(pipe, job1, [root1, root2, page11, page21, page22], [page211]) + assert (await _get_page_depths(pipe, page211)) == [2] + # Add new path for page211 of depth 1, which should make it eligible + page23, page24 = await make_pages(q, job1, "c", "d", parent_page = root2) + assert page23 == page211 + assert (await _get_page_depths(pipe, page211, page24)) == [1, 1] + await _assert_eligible_jobs(pipe, job1, [root1, root2, page11, page21, page22, page23, page24], []) + # Add new path of depth 4, which should have no effect + (page2111,) = await make_pages(q, job1, "c", parent_page = page211) + assert page2111 == page23 + assert (await _get_page_depths(pipe, page211)) == [1] + await _assert_eligible_jobs(pipe, job1, [root1, root2, page11, page21, page22, page23, page24], []) + # Add some cycles to ensure nothing hangs + (root1a,) = await make_pages(q, job1, "foo", parent_page = root1) + assert root1a == root1 + assert (await _get_page_depths(pipe, root1)) == [0] + await _assert_eligible_jobs(pipe, job1, [root1, root2, page11, page21, page22, page23, page24], []) + (root1b,) = await make_pages(q, job1, "foo", parent_page = page11) + assert root1b == root1 + assert (await _get_page_depths(pipe, root1)) == [0] + await _assert_eligible_jobs(pipe, job1, [root1, root2, page11, page21, page22, page23, page24], []) + +@_test +async def test_attempt_id_to_job_id(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests whether attempt_id_to_job_id functions as expected. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0) + + job1 = await make_job(q, concurrency = 1) + await make_pages(q, job1, "item1") + cl = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert cl and (await pipe.parent.attempt_id_to_job_id(cl.attempt_id)) == (job1, tuple()) + res2 = await pipe.parent.attempt_id_to_job_id(cl.attempt_id, [model.jobs.c.depth]) + assert res2 == (job1, (None,)) + +@_test +async def test_compute_page_settings(): + rules = [ + db.JobRule(r"", {"accept": False}), + db.JobRule(r"^https?://hello\d\.very-good-quality-co\.de/", {"ua": "ua1"}), + db.JobRule(r"test", {"skip": True, "custom_js": "foo"}), + db.JobRule(r"aaa", {"custom_js": None, "accept": True}), + ] + + tests = ( + ("", db.PageSettings(accept = False)), + ("https://example.org", db.PageSettings(accept = False)), + ("https://hello4.very-good-quality-co.de/robots.txt", db.PageSettings(accept = False, ua = "ua1")), + ("https://hello6.very-good-quality-co.de/test", db.PageSettings(accept = False, ua = "ua1", skip = True, custom_js = "foo")), + ("http://example.com/aaa", db.PageSettings(accept = True)), + ("http://example.com/aaa/test", db.PageSettings(accept = True, skip = True)), + ) + for url, expected in tests: + result = db.Connection.compute_page_settings(rules, url) + assert result == expected + with pytest.raises(AssertionError): + result = db.Connection.compute_page_settings(rules, "") + assert result == db.PageSettings(accept = True) + +@_test +async def test_tag_rowcount(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests the return value in create_tags, remove_tags, and get_tags. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe") + async def create(expected, *tags): + # There doesn't currently seem to be an easy way to return this, so don't test it for now + await pipe.create_tags(*tags) + + async def remove(expected, *tags): + r = await pipe.remove_tags(*tags) + assert r == expected + + async def get(*expected): + r = await pipe.get_tags() + assert r == set(expected) + + await get() + await create(2, "foo", "bar") + await get("foo", "bar") + await remove(2, "foo", "bar") + await get() + await create(2, "foo", "bar", "foo") + await remove(1, "foo", "foo", "baz") + await get("bar") + await create(2, "foo", "bar", "baz") + await get("foo", "bar", "baz") + await remove(3, "foo", "bar", "baz") + await get() + +@_test +async def test_job_rule_insertion(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests create_job_rule and remove_job_rule. + """ + async with engine.connect() as conn: + q = db.Connection(conn) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0) + job1 = await make_job(q) + + # Create our job rules + _, r1 = await q.create_job_rule(job1, None, db.JobRule("^https?://", {"ua": "minimal", "custom_js": "foo"})) + _, r2 = await q.create_job_rule(job1, None, db.JobRule("^https?://example.org$", {"ua": "stealth"})) + rs3, r3 = await q.create_job_rule(job1, r2, db.JobRule("https?://", {"custom_js": "bar"})) + # Test ensure_ruleset, both valid and invalid + rs4, r4 = await q.create_job_rule(job1, r1, db.JobRule("", db.PageSettings(skip = True).as_dict()), ensure_ruleset = rs3) + with pytest.raises(db.RulesetConflict): + await q.create_job_rule(job1, None, db.JobRule("", {"custom_js": "baz"}), ensure_ruleset = rs3) + rs, rules = await q.get_job_ruleset(job1) + assert rs == rs4, "Ruleset ID changed unexpectedly!" 
+ assert [rule.scope for rule in rules] == ["", "^https?://", "https?://", "^https?://example.org$"] + + # Create some pages. No trailing slash, so r2 applies and ua will be stealth + await make_pages(q, job1, "http://example.org", "https://example.org") + # Ensure page settings are computed correctly + assert q.compute_page_settings(rules, "http://example.org") == db.PageSettings(ua = "stealth", custom_js = "bar", skip = True) + # Skip rule is in place, so no item should be dequeued + assert (await pipe._find_claimable_page(job1)) is None + + # No more skip rule... + await q.remove_job_rule(job1, r4) + rs, rules = await q.get_job_ruleset(job1) + # Ensure the correct one (i.e. the empty regex) was removed, and that computed page settings change accordingly + assert [rule.scope for rule in rules] == ["^https?://", "https?://", "^https?://example.org$"] + assert q.compute_page_settings(rules, "http://example.org") == db.PageSettings(ua = "stealth", custom_js = "bar", skip = False) + # Queue some more pages, now with trailing slash (to prevent unique conflict) + (page1, page2) = await make_pages(q, job1, "https://example.org/", "http://example.org/") + info = await pipe._find_claimable_page(job1) + assert info == db.PendingPage(page1, "https://example.org/", rs, db.PageSettings(ua = "minimal", custom_js = "bar")) + + rm1, oldval = await q.remove_job_rule(job1, 0) + assert oldval.scope == "^https?://" + rm2, oldval = await q.remove_job_rule(job1, 1, ensure_ruleset = rm1) + assert oldval.scope == "^https?://example.org$" + with pytest.raises(db.RulesetConflict): + await q.remove_job_rule(job1, 0, ensure_ruleset = rm1) + await q.remove_job_rule(job1, 0, ensure_ruleset = rm2) + rs, rules = await q.get_job_ruleset(job1) + assert rules == [] + + with pytest.raises(db.NoSuchThingError): + await q.create_job_rule(job1, 1000, db.JobRule("", {})) + with pytest.raises(db.NoSuchThingError): + await q.remove_job_rule(job1, 1000) + +@_test +async def 
test_job_rule_removal_by_scope(engine: sqlalchemy.ext.asyncio.AsyncEngine): + async with engine.connect() as conn: + q = db.Connection(conn) + job1 = await make_job(q) + + rs0, rules = await q.get_job_ruleset(job1) + assert rs0.int == 0 + assert rules == [] + + rs1, idx = await q.create_job_rule(job1, None, db.JobRule("foo", {}), ensure_ruleset = rs0) + assert idx == 0 + rs2, idx = await q.create_job_rule(job1, None, db.JobRule("bar", {}), ensure_ruleset = rs1) + assert idx == 1 + rs3, idx = await q.create_job_rule(job1, None, db.JobRule("foo", {})) + assert idx == 2 + rs4, idx = await q.create_job_rule(job1, None, db.JobRule("foo", {})) + assert idx == 3 + + with pytest.raises(db.RulesetConflict): + await q.create_job_rule(job1, None, db.JobRule("foo", {}), rs0) + + rs4_, rules = await q.get_job_ruleset(job1) + assert rs4 == rs4_ + assert rules == [db.JobRule("foo", {}), db.JobRule("bar", {}), db.JobRule("foo", {}), db.JobRule("foo", {})] + + rs5, num_removed = await q.remove_job_rules_by_scope(job1, "foo", rs4) + assert num_removed == 3 + + rs5_, rules = await q.get_job_ruleset(job1) + assert rs5 == rs5_ + assert rules == [db.JobRule("bar", {})] + + with pytest.raises(db.RulesetConflict): + await q.remove_job_rules_by_scope(job1, "foo", rs4) + + rs5__, num_removed = await q.remove_job_rules_by_scope(job1, "foo", rs5) + assert num_removed == 0 + assert rs5__ == rs5 + + rs6, num_removed = await q.remove_job_rules_by_scope(job1, "bar", rs5) + assert num_removed == 1 + rs6_, num_removed = await q.remove_job_rules_by_scope(job1, "bar", rs6) + assert num_removed == 0 + assert rs6_ == rs6 + + rs6__, rules = await q.get_job_ruleset(job1) + assert rs6__ == rs6 + assert rules == [] + +@_test +async def test_job_rule_edge_cases(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests some possible edge cases related to job rulesets. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + job1 = await make_job(q) + + id1, _ = await q.get_job_ruleset(job1) + assert id1 == model.UUID(int = 0) + + # This ruleset is empty! Deletion shouldn't work. + with pytest.raises(db.NoSuchThingError): + await q.remove_job_rule(job1, 0, id1) + + # Negative indexing probably shouldn't be allowed. + with pytest.raises(db.NoSuchThingError): + await q.create_job_rule(job1, -1, db.JobRule("", {})) + # Inserting at len(rules) to append is not supported. + with pytest.raises(db.NoSuchThingError): + await q.create_job_rule(job1, 0, db.JobRule("", {})) + # No writes have actually gone through, so the ID should remain unchanged. + id2, _ = await q.get_job_ruleset(job1) + assert id1 == id2 + await q.create_job_rule(job1, None, db.JobRule("", {})) + # And check negative indexing when there *is* a rule, too. + with pytest.raises(db.NoSuchThingError): + await q.create_job_rule(job1, -1, db.JobRule("", {})) + with pytest.raises(db.NoSuchThingError): + await q.remove_job_rule(job1, -1) + # Removing len(rules) should definitely not work. + with pytest.raises(db.NoSuchThingError): + await q.remove_job_rule(job1, 1) + + # Ensure that conflicts do not change the ID either. + id3, _ = await q.get_job_ruleset(job1) + with pytest.raises(db.RulesetConflict): + await q.create_job_rule(job1, None, db.JobRule("", {}), id1) + id4, _ = await q.get_job_ruleset(job1) + assert id3 == id4 + +@_test +async def test_ruleset_removal_order(engine: sqlalchemy.ext.asyncio.AsyncEngine): + """ + Tests that removing rules preserves the order of the other ones. + Kind of tested above, but this is a little more explicit. 
+ """ + async with engine.connect() as conn: + q = db.Connection(conn) + job1 = await make_job(q, concurrency = 1) + + for scope in ("foo", "removeme", "bar", "removeme", "baz", "quux"): + await q.create_job_rule(job1, None, db.JobRule(scope, {})) + await q.remove_job_rule(job1, 0) + await q.remove_job_rule(job1, 4) + _, rules = await q.get_job_ruleset(job1) + assert [rule.scope for rule in rules] == ["removeme", "bar", "removeme", "baz"] + await q.remove_job_rules_by_scope(job1, "removeme") + _, rules = await q.get_job_ruleset(job1) + assert [rule.scope for rule in rules] == ["bar", "baz"] + +@_test +async def test_many_skipped_jobs(engine: sqlalchemy.ext.asyncio.AsyncEngine): + async with engine.connect() as conn: + q = db.Connection(conn) + job1 = await make_job(q, concurrency = 1) + await q.create_job_rule(job1, None, db.JobRule("^skipped_", {"skip": True, "custom_js": "foo"})) + await make_pages(q, job1, *[f"skipped_{i}" for i in range(100)], "hello") + await q.create_job_rule(job1, None, db.JobRule("", {"ua": "stealth", "accept": True})) + await q.create_pipeline("pipe", False, "password") + pipe = await q.pipeline("pipe", 0) + + claim = await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + assert claim + assert claim.payload == "hello" + assert claim.settings == db.PageSettings(ua = "stealth", accept = True) + with pytest.raises(db.JobExhausted): + await pipe.find_claim_page("", 0, model.JobType.BROZZLER) + +@db._wrap_serialization_failure +async def _commit(conn): + await conn.commit() + +@db._wrap_serialization_failure +async def _execute(conn, q): + await conn.execute(q) + +@_test +async def test_serializable_wrapper(engine: sqlalchemy.ext.asyncio.AsyncEngine): + async with engine.connect() as conn: + await conn.execute(sqlalchemy.insert(model.options).values(key = "hi", value = "bye")) + await conn.execute(sqlalchemy.insert(model.options).values(key = "bye", value = "hi")) + await conn.commit() + async with engine.connect() as conn1: + async with 
engine.connect() as conn2: + q1 = sqlalchemy.select(model.options.c.value).where(model.options.c.key == "hi") + q2 = sqlalchemy.update(model.options).where(model.options.c.value == "bye").values(value = "cye") + q3 = sqlalchemy.update(model.options).where(model.options.c.value == "hi").values(value = "dye") + await conn1.execute(q1) + await conn2.execute(q2) + await conn1.execute(q3) + await conn2.commit() + with pytest.raises(db.SerializationFailure): + await _commit(conn1) + await conn1.rollback() + with pytest.raises(sqlalchemy.exc.IntegrityError): + q1 = sqlalchemy.insert(model.options).values(key = "hi", value = "eye") + await _execute(conn1, q1) + + +# Test result duplication checking +# Test multiple of the same payload in create_page + # And the niceness update. +# Test nonexistence errors. diff --git a/tracker/common/db.py b/tracker/common/db.py new file mode 100644 index 0000000..c6ea202 --- /dev/null +++ b/tracker/common/db.py @@ -0,0 +1,1039 @@ +""" +Various database functions, put here for ease of reuse (and testing). + +The functions in this file that take a connection parameter do not automatically commit. +Transaction management is left to the caller. +""" + +import dataclasses +import typing +import os +import regex + +import uuid_utils.compat as uuid +import uuid_utils +import argon2 +hasher = argon2.PasswordHasher(time_cost = 2, memory_cost = 47104, parallelism = 1) # slightly higher than OWASP cheat sheet + +import sqlalchemy, sqlalchemy.ext.asyncio, sqlalchemy.exc, sqlalchemy.dialects.postgresql +from sqlalchemy.ext.asyncio import AsyncConnection + +import asyncpg + +from . 
def _wrap_serialization_failure(f):
    """
    Decorator that converts serialization exceptions to SerializationFailure.

    f must be a coroutine function. The returned wrapper awaits it and passes
    through its return value. A sqlalchemy DBAPIError whose original driver
    error was caused by an asyncpg SerializationError is re-raised as
    SerializationFailure (chained to the original error); every other
    exception propagates unchanged.

    The wrapper keeps the name, docstring and other metadata of f, so
    tracebacks and introspection show the real function instead of "newf".
    """
    # Imported locally so this helper does not depend on a file-level import.
    import functools

    @functools.wraps(f)
    async def newf(*args, **kwargs):
        try:
            return await f(*args, **kwargs)
        except sqlalchemy.exc.DBAPIError as e:
            if e.orig and isinstance(e.orig.__cause__, asyncpg.exceptions.SerializationError):
                # Chain explicitly so the driver-level error stays visible.
                raise SerializationFailure() from e
            raise
    return newf
def parse_id(id: str) -> model.UUID:
    """
    Parses the textual form of an ID into a model.UUID.

    Raises InvalidIdError if the string is not a valid UUID. The underlying
    ValueError is attached as the cause.
    """
    try:
        return model.UUID(id)
    except ValueError as err:
        raise InvalidIdError() from err
model.UUID + page_settings: "PageSettings" + +def _find_claimable_job_q(pipeline: str | None, type: model.JobType, matchonly: bool, include_existing: model.UUID | None, limit: bool): + """ + Creates a query for find_claimable_job. + If pipeline is null, remove all filters on tag. This is useful when displaying the queue. + """ + if pipeline is not None: + tag_criteria = sqlalchemy.exists().where((model.tags.c.pipeline_id == pipeline) & (model.jobs.c.tag == model.tags.c.tag)) + if matchonly: + tag_where_clause = (model.jobs.c.tag != None) & tag_criteria + else: + tag_where_clause = (model.jobs.c.tag == None) | tag_criteria + else: + tag_where_clause = sqlalchemy.true() + + concurrency_where = model.jobs.c.concurrency > model.jobs.c.active_claims + if include_existing: + concurrency_where |= ((model.jobs.c.job_id == include_existing) & (model.jobs.c.concurrency == model.jobs.c.active_claims)) + + q = ( + sqlalchemy.select(model.jobs.c.job_id) + .where(model.jobs.c.type == type) + .where( model.jobs.c.status == model.JobStatus.ACTIVE) + .where(concurrency_where) + .where(tag_where_clause) + .order_by(*model.jobs_dequeue_order) + #.with_for_update(key_share = True) + ) + if limit: + q = q.limit(1) + return q + +@dataclasses.dataclass +class PageSettings: + ua: str = "default" + custom_js: str | None = None + skip: bool = False + accept: bool = False + + def __or__(self, value): + if not isinstance(value, dict): + return NotImplemented + + return PageSettings( + ua = value.get("ua", self.ua), + custom_js = value.get("custom_js", self.custom_js), + skip = value.get("skip", self.skip), + accept = value.get("accept", self.accept), + ) + + def as_dict(self): + return dataclasses.asdict(self) + +@dataclasses.dataclass +class PageClaimInfo: + page_id: model.UUID + job_id: model.UUID + attempt_id: model.UUID + payload: str + settings: PageSettings + + def as_json_friendly_dict(self): + return { + "page_id": str(self.page_id), + "job_id": str(self.job_id), + "attempt_id": 
@_wrap_serialization_failure
async def create_pages(self, job_id: model.UUID, pages: typing.Iterable[PageCreation]) -> dict[model.UUID, model.UUID]:
    """
    Adds pages to the database for a particular job.
    Existing pages will be ignored, but the niceness value may be overwritten if the new one is lower.

    Returns a mapping of IDs passed to IDs actually inserted (or existing IDs).
    If no pages are given, nothing is written and an empty mapping is returned.

    If a page with the same payload is passed more than once, only one row is written for it.
    That row uses the lowest niceness among the duplicates; its other columns come from the first entry.
    """
    # One row per distinct payload, in first-seen order.
    rows_by_payload: dict[str, dict] = {}
    # Because payloads are unique, we RETURN the payload to associate it with the caller-given ID.
    # Sentinel columns or postgres' future feature of EXCLUDED clause in RETURNING won't work
    # because we want to include deduplicated rows in the return value.
    payload_to_id_mapping: dict[str, list[model.UUID]] = {}
    relation_values: dict[model.UUID, dict] = {}
    for page in pages:
        row = rows_by_payload.get(page.payload)
        if row is None:
            # Only insert it if it's never been seen before. (Otherwise, postgres gets sad.)
            rows_by_payload[page.payload] = dict(
                page_id = page.page_id,
                job_id = job_id,
                payload = page.payload,
                nice = page.nice,
                status = page.status,
            )
            payload_to_id_mapping[page.payload] = []
        else:
            # Same rule as the ON CONFLICT clause below: the lowest niceness wins.
            row['nice'] = min(row['nice'], page.nice)
        payload_to_id_mapping[page.payload].append(page.page_id)
        relation_values[page.page_id] = dict(
            page_id = page.page_id,
            job_id = job_id,
            parent_page = page.parent_page,
            # Depth is overwritten by the INSERT trigger.
            depth = -1,
        )

    if not rows_by_payload:
        # Executing with an empty parameter list is not a no-op: the statement
        # would still run once without any page values.
        return {}

    num_retries = sqlalchemy.cast(
        sqlalchemy.select(model.options.c.value).where(model.options.c.key == "tries").scalar_subquery(),
        sqlalchemy.SmallInteger
    )
    q = sqlalchemy.dialects.postgresql.insert(model.pages).values(attempts_remaining = num_retries)
    q = q.on_conflict_do_update(constraint = model.pages_index_unique, set_ = dict(
        nice = sqlalchemy.func.least(q.excluded.nice, model.pages.c.nice),
    ))
    # NOTE(review): "old.page_id" relies on the server accepting OLD in RETURNING - confirm the minimum Postgres version.
    q = q.returning(sqlalchemy.text("old.page_id"), sqlalchemy.text("page_id"), model.pages.c.payload)
    res = await self.conn.execute(q, list(rows_by_payload.values()))

    # In case there are existing items, give the actual page IDs to the caller
    id_mapping: dict[model.UUID, model.UUID] = {}
    for existing_id, new_id, payload in res.all():
        # If Postgres gave an existing id, use that. Otherwise, use the id we gave.
        actual_id = existing_id or new_id
        # If the caller passes multiple page objects with the same payload, handle them all
        for caller_given_id in payload_to_id_mapping[payload]:
            # Associate the caller-given ID in the return value with the actual ID
            id_mapping[caller_given_id] = actual_id
            # Update the relation for that id with the real id
            relation_values[caller_given_id]['page_id'] = actual_id

    # Add relations
    relation_q = sqlalchemy.insert(model.relations)
    await self.conn.execute(relation_q, list(relation_values.values()))
    return id_mapping
@_wrap_serialization_failure
async def pipeline(self, pipeline_id: str, *slots_to_prepare):
    """
    Returns a Pipeline object with the current connection, and prepares the given slots (if any).

    If no slots are given, the pipeline's existence is checked instead and
    NoSuchPipelineError is raised when it does not exist.
    """
    if slots_to_prepare:
        await self._prepare_slot(pipeline_id, *slots_to_prepare)
    else:
        # _ensure_pipeline takes only the pipeline ID; there are no slots to forward here.
        await self._ensure_pipeline(pipeline_id)
    return Pipeline(self, pipeline_id)
+ However, the status will be fixed as soon as a pipeline attempts to claim an item. + + Returns the new status. + """ + # Absolute chonker of a query + q = ( + sqlalchemy.update(model.jobs) + .values( + status = sqlalchemy.case( + # If the job is aborted _and_ allow_resumption is False, do nothing. + ( + (model.jobs.c.status == model.JobStatus.ABORTED) & (not allow_resumption), + model.jobs.c.status + ), + # If the job has no available pages... + ( + ~sqlalchemy.exists(self._all_pending_pages_q(job_id)), + sqlalchemy.case( + # ... set status to DRAINING or DONE based on the number of claimed pages. + ( + sqlalchemy.exists( + sqlalchemy.select(model.pages) + .where(model.pages.c.job_id == model.jobs.c.job_id) + .where(model.pages.c.status == model.PageStatus.CLAIMED) + ), + sqlalchemy.text("'DRAINING'::jobstatus"), + ), + else_ = sqlalchemy.text("'DONE'::jobstatus") + ) + ), + # Otherwise, the job is active. + else_ = sqlalchemy.text("'ACTIVE'::jobstatus"), + ) + ) + .where(model.jobs.c.job_id == job_id) + .returning(model.jobs.c.status) + ) + res = await self.conn.execute(q) + row = res.first() + if not row: + raise NoSuchThingError + + new_status = row[0] + if new_status in (model.JobStatus.DONE, model.JobStatus.ABORTED): + # Disclaim this job from all pipelines + q = ( + sqlalchemy.update(model.claims) + .where(model.claims.c.job_id == job_id) + .where( + (model.claims.c.lock == None) | (model.claims.c.lock == model.ClaimLock.UNTIL_FINISHED) + ) + .values(job_id = None, lock = None) + ) + await self.conn.execute(q) + + return new_status + + @_wrap_serialization_failure + async def abort_job(self, job_id: model.UUID): + """ + Sets the status of the given job to ABORTED. 
@_wrap_serialization_failure
async def is_single_job(self, job_id: model.UUID) -> bool:
    """
    Returns True if the job has exactly one page, and False otherwise.
    """
    # Fetch at most two page IDs: that is enough to tell "exactly one" apart
    # from "none" and "several". (A LIMIT on a count(*) query only limits the
    # single result row, not the rows being counted.)
    q = (
        sqlalchemy.select(model.pages.c.page_id)
        .where(model.pages.c.job_id == job_id)
        .limit(2)
    )
    res = await self.conn.execute(q)
    return len(res.all()) == 1
@_wrap_serialization_failure
async def new_ruleset(self, job_id: model.UUID, ruleset_id: model.UUID, rules: list[JobRule]):
    """
    Stores a new ruleset for a job, which takes over from all earlier ones.
    Raises RulesetConflict (carrying the conflicting ID) if the job already has a ruleset
    whose ID is greater than or equal to ruleset_id, which could be caused by a bad system clock.
    """
    # Look for any ruleset of this job that is not older than the new one.
    # Serializable isolation should prevent any race condition here.
    conflict_query = (
        sqlalchemy.select(model.job_rulesets.c.job_ruleset_id)
        .where(model.job_rulesets.c.job_id == job_id)
        .where(model.job_rulesets.c.job_ruleset_id >= ruleset_id)
        .limit(1)
    )
    conflict_result = await self.conn.execute(conflict_query)
    conflicting = conflict_result.one_or_none()
    if conflicting:
        raise RulesetConflict(conflicting[0])

    # Rules are stored in their (scope, payload) database form.
    new_row = dict(
        job_ruleset_id = ruleset_id,
        job_id = job_id,
        rules = [rule.for_db() for rule in rules],
    )
    await self.conn.execute(sqlalchemy.insert(model.job_rulesets), new_row)
@_wrap_serialization_failure
async def remove_job_rule(self, job_id: model.UUID, index: int, ensure_ruleset: model.UUID | None = None) -> tuple[model.UUID, JobRule]:
    """
    Removes the job rule at the given index and stores the result as a new ruleset.

    Raises RulesetConflict if ensure_ruleset is given and is not the current ruleset ID.
    Raises NoSuchThingError if there is no rule at that index (negative indexes are rejected).
    Returns the new ruleset ID and the rule that was removed.
    """
    current_id, rules = await self.get_job_ruleset(job_id, for_update = True)
    if ensure_ruleset is not None and ensure_ruleset != current_id:
        raise RulesetConflict(current_id)
    # Only positions that actually exist may be removed.
    if not 0 <= index < len(rules):
        raise NoSuchThingError(index)
    removed = rules.pop(index)
    replacement_id = generate_id()
    await self.new_ruleset(job_id, replacement_id, rules)
    return replacement_id, removed
@staticmethod
def compute_page_settings(rules: list[JobRule], url: str) -> PageSettings:
    """
    Computes the settings for a URL by layering the payload of every rule
    whose scope matches it, in rule order, on top of the default settings.
    """
    # Lazily yields payloads of matching rules; each search gets timeout = 15.
    matching_payloads = (
        rule.payload
        for rule in rules
        if rule.compiled_scope.search(url, timeout = 15)
    )
    result = PageSettings()
    for payload in matching_payloads:
        result = result | payload
    return result
@_wrap_serialization_failure
async def info(self, slot: int) -> PipelineInfo:
    """
    Fetches the pipeline's matchonly flag together with the claim (job ID and lock) held in the given slot.

    Raises NoSuchPipelineError if no row exists for this pipeline and slot.
    """
    # Inner join: a pipeline without a claim row for this slot yields no result.
    claim_for_slot = (
        (model.pipelines.c.pipeline_id == model.claims.c.pipeline_id)
        & (model.claims.c.slot == slot)
    )
    query = (
        sqlalchemy.select(model.pipelines.c.matchonly, model.claims.c.job_id, model.claims.c.lock)
        .join(model.claims, claim_for_slot, isouter = False)
        .where(model.pipelines.c.pipeline_id == self.pipeline_id)
    )
    row = (await self.conn.execute(query)).first()
    if not row:
        raise NoSuchPipelineError("Pipeline slot should be registered before use")
    matchonly, claimed_job, claim_lock = row
    return PipelineInfo(matchonly = matchonly, current_claim = (claimed_job, claim_lock))
@_wrap_serialization_failure
async def _set_claim(self, slot: int, job: model.UUID | None):
    """
    Points the given slot's claim at a job (or at nothing) and clears any claim lock.
    The active claim counters of the previously claimed job and of the new job are adjusted to match.

    Raises NoSuchPipelineError if the slot has no claim row.
    """
    swap_claim = (
        sqlalchemy.update(model.claims)
        .where(model.claims.c.pipeline_id == self.pipeline_id)
        .where(model.claims.c.slot == slot)
        .values(job_id = job, lock = None)
        .returning(sqlalchemy.text("old.job_id"))
    )
    previous = (await self.conn.execute(swap_claim)).first()
    if not previous:
        raise NoSuchPipelineError()
    old_job = previous[0]
    if old_job == job:
        # Same job as before: the counters are already right.
        return

    # TODO: Combine these into one query
    # The old job loses a claim first, then the new job gains one.
    adjustments = (
        (old_job, model.jobs.c.active_claims - 1),
        (job, model.jobs.c.active_claims + 1),
    )
    for target_job, new_count in adjustments:
        if target_job:
            adjust = (
                sqlalchemy.update(model.jobs)
                .where(model.jobs.c.job_id == target_job)
                .values(active_claims = new_count)
            )
            await self.conn.execute(adjust)
@_wrap_serialization_failure
async def _find_claimable_page(self, job_id: model.UUID, *, _page_id: model.UUID | None = None) -> PendingPage | None:
    """
    Finds a page to claim for a particular job, returning None if nothing was found.

    Pending pages whose settings say "skip" and that are encountered before the
    result (or before the queue runs out) are set to the SKIPPED status, so
    later claims do not have to evaluate them again.

    If _page_id is not None, either that specific page will be returned (if eligible), or nothing.
    This is used during unit testing.

    Returns the PendingPage to claim, or None.
    """
    to_skip = []
    found = None
    pending = self.parent.all_pending_pages(job_id, _page_id = _page_id)
    try:
        async for page in pending:
            if page.page_settings.skip:
                to_skip.append(page.page_id)
            else:
                found = page
                break
    finally:
        # Close the generator explicitly so its streaming result is released
        # before any further statement is executed on this connection.
        await pending.aclose()

    # Record skipped pages whether or not a claimable page was found.
    if to_skip:
        update_query = (
            sqlalchemy.update(model.pages)
            .values(status = model.PageStatus.SKIPPED)
            .where(model.pages.c.page_id.in_(to_skip))
        )
        await self.conn.execute(update_query)
    return found
@_wrap_serialization_failure
async def find_claim_page(self, pipeline_version: str, slot: int, type: model.JobType) -> PageClaimInfo | None:
    """
    Claims a page for the given slot.

    The steps are:
    1. Read the slot's current claim.
    2. Unless the claim is locked, look for the most appealing claimable job; if it differs
       from the current one, move the claim over to it.
    3. Claim a page from whichever job the slot now points at.

    Returns a PageClaimInfo object if a page was claimed, or None if the slot has no job.
    """
    pipeline_info = await self.info(slot)
    assert pipeline_info is not None, "Pipeline disappeared"
    claimed_job, claim_lock = pipeline_info.current_claim

    if claim_lock is None:
        # Unlocked claims may move to a better job.
        candidate = await self._find_claimable_job(type, pipeline_info.matchonly, claimed_job)
        if candidate and candidate != claimed_job:
            await self._set_claim(slot, candidate)
            claimed_job = candidate

    if not claimed_job:
        return None
    return await self._claim_page(claimed_job, pipeline_version)
@_wrap_serialization_failure
async def authenticate(self, password: str):
    """
    Authenticates the pipeline, raising AuthenticationFailure if the pipeline doesn't exist or the password doesn't match.
    """
    q = sqlalchemy.select(model.pipelines.c.pipeline_secret).where(model.pipelines.c.pipeline_id == self.pipeline_id)
    result = await self.conn.execute(q)
    row = result.first()
    if not row:
        raise AuthenticationFailure
    # The verification must not live inside an assert statement: asserts are
    # removed entirely under "python -O", which would skip the password check.
    try:
        verified = hasher.verify(row[0], password)
    except argon2.exceptions.VerifyMismatchError as err:
        raise AuthenticationFailure from err
    if verified is not True:
        # Fail closed on any result other than an explicit True.
        raise AuthenticationFailure
async def _set_page_status_q(self, page_id: model.UUID, allow_retry: bool) -> int:
    """
    Updates a page's status once an attempt on it has ended.
    A CLAIMED page goes back to READY; any other status is left as it is.
    When allow_retry is False, attempts_remaining is also set to zero.

    Returns the new value of attempts_remaining.
    """
    # Pages that were SKIPPED or STASHED in the meantime must not return to READY.
    # This is raw SQL on purpose: expressing it through SQLAlchemy constructs
    # failed with "expected str, got PageStatus", which looks like a SQLAlchemy
    # bug (in the stack trace CLAIMED became the literal string "CLAIMED" while
    # READY stayed a PageStatus).
    updates = {
        "status": sqlalchemy.text("CASE WHEN status = 'CLAIMED' THEN 'READY' ELSE status END"),
    }
    if not allow_retry:
        updates["attempts_remaining"] = 0

    statement = (
        sqlalchemy.update(model.pages)
        .where(model.pages.c.page_id == page_id)
        .values(updates)
        .returning(model.pages.c.attempts_remaining)
    )
    row = (await self.conn.execute(statement)).first()
    assert row
    return row[0]
+ """ + page_id = await self._complete_attempt(attempt_id, error) + return await self._set_page_status_q(page_id, not fatal) + + @_wrap_serialization_failure + async def create_result(self, attempt_id: model.UUID, result_id: model.UUID, type: model.ResultType, payload: typing.Any): + """ + Creates a result with the given result ID. + Note: If a result already exists with that ID, this method will silently do nothing. + """ + q = sqlalchemy.dialects.postgresql.insert(model.results).on_conflict_do_nothing() + val = dict( + result_id = result_id, + attempt_id = attempt_id, + type = type, + payload = payload, + ) + await self.conn.execute(q, [val]) diff --git a/tracker/common/model.py b/tracker/common/model.py new file mode 100644 index 0000000..744f153 --- /dev/null +++ b/tracker/common/model.py @@ -0,0 +1,235 @@ +from sqlalchemy import ForeignKey, Index, MetaData, Table +import sqlalchemy, sqlalchemy.dialects.postgresql, sqlalchemy.event + +import enum +import typing + +from uuid import UUID + +SCHEMA_VERSION = 1 + +class JobStatus(enum.Enum): + ACTIVE = 0 + + DRAINING = 1 + """The queue is empty, but some in-progress pages remain.""" + + DONE = 2 + + ABORTED = 3 + +class PageStatus(enum.Enum): + READY = 0 + #DEFERRED = 1 + CLAIMED = 2 + SKIPPED = 3 + STASHED = 4 + +class ResultType(enum.Enum): + OUTLINKS = 0 + STATUS_CODE = 1 + FINAL_URL = 2 + REQUISITES = 3 + SCREENSHOT = 4 + CUSTOM_JS_SCREENSHOT = 5 + CUSTOM_JS = 6 + +class JobType(enum.Enum): + BROZZLER = 0 + NONE = -1 + +class Column(sqlalchemy.Column): + inherit_cache = True + + def __init__(self, *args, **kwargs): + kwargs['nullable'] = kwargs.get("nullable", False) + super().__init__(*args, **kwargs) + +metadata_obj = MetaData() + +pipelines = Table( + "pipelines", + metadata_obj, + Column("pipeline_id", sqlalchemy.Text, primary_key = True), + Column("pipeline_secret", sqlalchemy.Text), + Column("matchonly", sqlalchemy.Boolean), +) + +# Pipeline tags, similar to ArchiveBot's substring match for selecting 
# pipelines to run on
# (e.g. select pipelines in Canada)
# ops can create tags as needed for specific tasks if necessary
tags = Table(
    "tags",
    metadata_obj,
    # Composite primary key: one row per (pipeline, tag) pair.
    Column("pipeline_id", sqlalchemy.Text, ForeignKey("pipelines.pipeline_id"), primary_key = True),
    Column("tag", sqlalchemy.Text, primary_key = True),
)

jobs = Table(
    "jobs",
    metadata_obj,
    Column("job_id", sqlalchemy.Uuid, primary_key = True),
    Column("type", sqlalchemy.Enum(JobType)),
    Column("status", sqlalchemy.Enum(JobStatus)),
    Column("active_claims", sqlalchemy.SmallInteger, default = 0),
    Column("depth", sqlalchemy.Integer, nullable = True),
    Column("concurrency", sqlalchemy.SmallInteger),
    Column("nice", sqlalchemy.Integer),
    Column("tag", sqlalchemy.Text(), nullable = True),
    Column("created_by", sqlalchemy.Text),
    Column("note", sqlalchemy.Text, nullable = True),
    Column("initial_page", sqlalchemy.Text),
    Column("metadata", sqlalchemy.dialects.postgresql.JSONB),
)
# Ordering shared by the dequeue index below and by any query that lists jobs in dequeue order.
jobs_dequeue_order = (jobs.c.nice, jobs.c.job_id)
# Partial index: only covers jobs whose status is ACTIVE or DRAINING.
jobs_dequeue_index = Index(
    "jobs_dequeue_index",
    jobs.c.type, *jobs_dequeue_order, jobs.c.tag,
    postgresql_where = (jobs.c.status.in_((JobStatus.ACTIVE, JobStatus.DRAINING))),
)

# Array schema: JSONB of [scope: str, payload: dict]
job_rulesets = Table(
    "job_rulesets",
    metadata_obj,
    Column("job_ruleset_id", sqlalchemy.Uuid, primary_key = True),
    Column("job_id", sqlalchemy.Uuid, ForeignKey("jobs.job_id")),
    Column("rules", sqlalchemy.dialects.postgresql.ARRAY(sqlalchemy.dialects.postgresql.JSONB, dimensions = 1, zero_indexes = True)),

    Index("job_rulesets_by_job_id", "job_id"),
)

class ClaimLock(enum.Enum):
    # Stored in claims.lock (nullable).
    UNTIL_FINISHED = 0
    INDEFINITELY = 1

claims = Table(
    "claims",
    metadata_obj,
    # Composite primary key: one row per (pipeline, slot).
    Column("pipeline_id", sqlalchemy.Text, ForeignKey("pipelines.pipeline_id"), primary_key = True),
    Column("slot", sqlalchemy.SmallInteger, primary_key = True),
    Column("job_id", sqlalchemy.Uuid, ForeignKey("jobs.job_id"), nullable = True),
    Column("lock", sqlalchemy.Enum(ClaimLock), nullable = True),

    Index("claims_index_by_job", "job_id"),
)

pages = Table(
    "pages",
    metadata_obj,
    Column("page_id", sqlalchemy.Uuid, primary_key = True),
    Column("job_id", sqlalchemy.Uuid, ForeignKey("jobs.job_id")),
    Column("payload", sqlalchemy.Text),
    # Note! This should be reset to 0 whenever attempts_remaining is manually changed
    Column("attempts", sqlalchemy.SmallInteger, default = 0),
    Column("attempts_remaining", sqlalchemy.SmallInteger),
    Column("nice", sqlalchemy.Integer),
    Column("status", sqlalchemy.Enum(PageStatus)),
)

# A payload may appear at most once per job.
pages_index_unique = Index("pages_unique_url", pages.c.job_id, pages.c.payload, unique = True)
# Dequeue ordering: nice + attempts first, then page_id.
pages_dequeue_order = (pages.c.nice + pages.c.attempts, pages.c.page_id)
# A page is dequeueable when it is READY and still has attempts remaining.
pages_dequeue_filter = (pages.c.status == PageStatus.READY) & (pages.c.attempts_remaining > 0)
# Partial index over dequeueable pages plus CLAIMED pages.
pages_dequeue_index = Index(
    "pages_dequeue_index",
    pages.c.job_id, *pages_dequeue_order,
    postgresql_where = pages_dequeue_filter | (pages.c.status == PageStatus.CLAIMED),
)

relations = Table(
    "relations",
    metadata_obj,
    Column("relation_id", sqlalchemy.Integer, sqlalchemy.Identity(), primary_key = True),
    Column("page_id", sqlalchemy.Uuid, sqlalchemy.ForeignKey(pages.c.page_id)),
    Column("job_id", sqlalchemy.Uuid, sqlalchemy.ForeignKey("jobs.job_id")),
    Column("parent_page", sqlalchemy.Uuid, sqlalchemy.ForeignKey("pages.page_id"), nullable = True),
    # Do not add code that modifies this value (or parent_page, or page_id) without adding a new trigger!
    Column("depth", sqlalchemy.Integer),

    Index("relations_by_page", "page_id", "depth"),
    Index("relations_by_job", "job_id"),
    Index("relations_by_parent", "parent_page"),
)
# PL/pgSQL trigger function. It computes NEW.depth from the parent's minimum depth and,
# when the new row is a shorter path to the page, shifts the depth of descendant relations.
relations_depth_function = sqlalchemy.DDL("""
CREATE FUNCTION update_relation_depth() RETURNS trigger AS $update_relation_depth$
    DECLARE
        existing_depth INTEGER; -- The existing lowest depth for the page.
        delta_depth INTEGER;
    BEGIN
        -- Calculate the current shortest path to any parent relation.
        -- If there is no parent, assume depth is 0.
        IF NEW.parent_page IS NULL THEN
            NEW.depth := 0;
        ELSE
            SELECT MIN(depth) + 1 INTO NEW.depth FROM relations WHERE page_id = NEW.parent_page;
        END IF;
        -- Calculate the current shortest path to any relation of this page.
        SELECT MIN(depth) INTO existing_depth FROM relations WHERE page_id = NEW.page_id;
        -- If existing_depth is NULL, the page has no existing relations, so there is nothing more to do.
        IF existing_depth IS NULL THEN
            RETURN NEW;
        END IF;
        -- Calculate the difference between the new relation's depth and the current minimum for the page.
        delta_depth := NEW.depth - existing_depth;
        -- If it is negative, the existing depth is higher - we've found a shorter path and now need to
        -- update the relation's children. Otherwise, there is nothing more to do.
        IF delta_depth >= 0 THEN
            RETURN NEW;
        END IF;
        UPDATE relations
            SET depth = relations.depth + delta_depth
            WHERE relations.job_id = NEW.job_id
            AND relations.relation_id IN (
                WITH RECURSIVE CTE (page_id, relation_id) AS (
                    SELECT r.page_id, r.relation_id FROM relations AS r WHERE parent_page = NEW.page_id
                    UNION
                    SELECT r.page_id, r.relation_id FROM relations AS r
                    INNER JOIN CTE ON CTE.page_id = r.parent_page
                )
                SELECT relation_id FROM CTE
            )
        ;
        RETURN NEW;
    END
$update_relation_depth$ LANGUAGE plpgsql;
    """)
relations_depth_trigger = sqlalchemy.DDL(
    "CREATE TRIGGER relations_depth_trigger BEFORE INSERT ON relations FOR EACH ROW EXECUTE FUNCTION update_relation_depth();"
)
# Both DDL statements run right after the relations table is created; the function is registered first.
sqlalchemy.event.listen(relations, "after_create", relations_depth_function)
sqlalchemy.event.listen(relations, "after_create", relations_depth_trigger)

attempts = Table(
    "attempts",
    metadata_obj,
    Column("attempt_id", sqlalchemy.Uuid, primary_key = True),
    Column("page_id", sqlalchemy.Uuid, ForeignKey("pages.page_id")),
Column("pipeline_id", sqlalchemy.Text), + Column("pipeline_version", sqlalchemy.Text), + Column("error", sqlalchemy.Text, nullable = True, default = None), + Column("finished", sqlalchemy.Boolean, default = False), + Column("ruleset_id", sqlalchemy.Uuid), + + # TODO: Does attempt_id need to be explicitly stated here? + Index("attempts_index_by_page", "page_id", "attempt_id"), +) + +# Page results +results = Table( + "results", + metadata_obj, + Column("result_id", sqlalchemy.Uuid, primary_key = True), + Column("attempt_id", sqlalchemy.Uuid, ForeignKey("attempts.attempt_id")), + Column("type", sqlalchemy.Enum(ResultType)), + Column("payload", sqlalchemy.dialects.postgresql.JSONB), + + Index("results_by_attempt", "attempt_id"), +) + +# Global options +options = sqlalchemy.Table( + "options", + metadata_obj, + Column("key", sqlalchemy.Text, primary_key = True), + Column("value", sqlalchemy.Text), +) diff --git a/tracker/dashboard/__main__.py b/tracker/dashboard/__main__.py new file mode 100644 index 0000000..027c2cf --- /dev/null +++ b/tracker/dashboard/__main__.py @@ -0,0 +1,407 @@ +from quart import Quart, abort, redirect, render_template, render_template_string, request, url_for +import werkzeug.exceptions +import os +import base64 +import dataclasses + +import sqlalchemy, sqlalchemy.ext.asyncio, sqlalchemy.dialects.postgresql + +from ..common import db, model + +class EscapingQuart(Quart): + def select_jinja_autoescape(self, filename: str) -> bool: + return (not filename) or filename.endswith(".j2") or super().select_jinja_autoescape(filename) + +app = EscapingQuart(__name__) +app.jinja_env.globals.update(isinstance = isinstance) + +DOCUMENTATION_URL = os.getenv("DOCUMENTATION_URL") + +NAV = ( + ("/", "Dashboard"), + #("/claims", "Claims"), + #("/pipelines", "Pipelines"), + ("/docs", "Documentation"), +) + +async def _setup_engine(): + global ENGINE + ENGINE = await db.create_engine() + +app.before_serving(_setup_engine) + +@dataclasses.dataclass +class 
RulesetInfoPacket: + id: model.UUID + rules: list[db.JobRule] + +@dataclasses.dataclass +class JobInfoPacket: + id: model.UUID + type: model.JobType + status: model.JobStatus + active_claims: list[str] + depth: int | None + concurrency: int + nice: int + tag: str | None + note: str | None + initial_page: str + ruleset: RulesetInfoPacket + +@dataclasses.dataclass +class ResultsInfoPacket: + screenshot: model.UUID | None = None + cjs_screenshot: model.UUID | None = None + outlinks: list[str] | None = None + requisites: list | None = None + status_code: int | None = None + final_url: str | None = None + +@dataclasses.dataclass +class AttemptInfoPacket: + id: model.UUID + page_id: model.UUID + pipeline_id: str + pipeline_version: str + error: str | None + finished: bool + ruleset: RulesetInfoPacket + applied_settings: db.PageSettings + results: ResultsInfoPacket + +@dataclasses.dataclass +class PageInfoPacket: + id: model.UUID + job_id: model.UUID + url: str + status: model.PageStatus + attempts: int + attempts_remaining: int + nice: int + + all_attempts: list[AttemptInfoPacket] + +@app.context_processor +def aaa(): + return {"nav": NAV, "len": len} + +@app.route("/") +async def home(): + async with ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + q = ( + sqlalchemy.select(model.jobs.c.job_id, model.jobs.c.initial_page, model.jobs.c.note) + .where(model.jobs.c.type == model.JobType.BROZZLER) + .where(model.jobs.c.status.in_((model.JobStatus.ACTIVE, model.JobStatus.DRAINING))) + .order_by(*model.jobs_dequeue_order) + ) + active_jobs = (await conn.execute(q)).all() + return await render_template("home.j2", jobs = active_jobs) + +@app.route("/docs") +async def docs(): + if DOCUMENTATION_URL: + return redirect(DOCUMENTATION_URL, 302) + return await render_template("error.j2", reason = "No documentation URL", description = "The DOCUMENTATION_URL environment variable was not set. 
Please report this!"), 500 + +@app.route("/item/translate") +async def translate_form_input(): + if "item" not in request.args: + abort(400) + id = request.args['item'] + async with ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + q = sqlalchemy.select( + sqlalchemy.exists(sqlalchemy.select(model.jobs).where(model.jobs.c.job_id == id)), + sqlalchemy.exists(sqlalchemy.select(model.pages).where(model.pages.c.page_id == id)) + ) + res = await conn.execute(q) + is_job, is_page = res.one() + if is_job: + if is_page: + return await render_template("error.j2", reason = "Ambiguous ID", description = "The ID you provided was found as both a job ID and a page ID. Please report this!"), 500 + return redirect(url_for("single_job", job_id = id)) + elif is_page: + return redirect(url_for("single_page", page_id = id)) + return await render_template("error.j2", reason = "No such ID", description = "No job or page was found with the provided ID."), 404 + +def route_with_json(route, **kwargs): + """ + Adds app.route for route and route + ".json". + The callback should take an argument called html, which indicates whether or not to return HTML. 
+ """ + assert "defaults" not in kwargs + def inner(cb): + html_cb = app.route( + route, + defaults = {"html": True}, + **kwargs + )(cb) + return app.route( + route + ".json", + defaults = {"html": False}, + **kwargs + )(html_cb) + return inner + +@route_with_json("/claims") +async def claims(html): + claims = QUEUE.claimed() + if html: + return await render_template("pending.j2", pending = claims, adj = "Claimed") + claims = [] + async for item in QUEUE.claimed(): + claims.append(item) + return {"status": 200, "claims": claims} + +@route_with_json("/page/") +async def single_page(page_id, html): + q = sqlalchemy.select(model.pages).where(model.pages.c.page_id == page_id) + attempt_q = ( + sqlalchemy.select(model.attempts, model.job_rulesets.c.rules) + .select_from(model.attempts) + .join(model.job_rulesets, model.attempts.c.ruleset_id == model.job_rulesets.c.job_ruleset_id) + .where(model.attempts.c.page_id == page_id) + ) + + async with ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + res = await conn.execute(q) + row = res.one_or_none() + if row is None: + if html: + return await render_template("error.j2", reason = f"Page ID {page_id} not found", description = f"No page with this ID exists.", show_item_search = True), 404 + return {"status": 404, "message": "Page ID not found"}, 404 + page_packet = PageInfoPacket(page_id, row.job_id, row.payload, row.status, row.attempts, row.attempts_remaining, row.nice, []) + + attempts_res = await conn.execute(attempt_q) + for row in attempts_res: + # TODO: Flatten this into the attempt_q query. 
+ results_q = sqlalchemy.select(model.results).where(model.results.c.attempt_id == row.attempt_id) + results_r = await conn.execute(results_q) + results = ResultsInfoPacket() + for result in results_r: + match result.type: + case model.ResultType.CUSTOM_JS_SCREENSHOT: + results.cjs_screenshot = result.result_id + case model.ResultType.FINAL_URL: + results.final_url = result.payload + case model.ResultType.OUTLINKS: + results.outlinks = result.payload + case model.ResultType.REQUISITES: + results.requisites = result.payload + case model.ResultType.SCREENSHOT: + results.screenshot = result.result_id + case model.ResultType.STATUS_CODE: + results.status_code = result.payload + rules = [db.JobRule(*i) for i in row.rules] + applied_settings = db.Connection.compute_page_settings(rules, page_packet.url) + page_packet.all_attempts.append(AttemptInfoPacket( + id = row.attempt_id, + page_id = page_id, + pipeline_id = row.pipeline_id, + pipeline_version = row.pipeline_version, + error = row.error, + finished = row.finished, + ruleset = RulesetInfoPacket(row.ruleset_id, rules), + applied_settings = applied_settings, + results = results, + )) + if html: + return await render_template("page.j2", page = page_packet) + v = dataclasses.asdict(page_packet) + v['status'] = v['status'].name + return {"status": 200, "page": v} + +@route_with_json("/job/") +async def single_job(job_id, html): + claim_q = ( + sqlalchemy.select(sqlalchemy.dialects.postgresql.array_agg(sqlalchemy.text("claims.*"))) + .where(model.claims.c.job_id == job_id) + .scalar_subquery() + ) + ruleset_q = ( + sqlalchemy.select(model.job_rulesets.c.job_ruleset_id, model.job_rulesets.c.rules) + .where(model.job_rulesets.c.job_id == job_id) + .order_by(model.job_rulesets.c.job_ruleset_id.desc()) + .limit(1) + .subquery() + ) + + q = ( + sqlalchemy.select( + model.jobs, + claim_q.label("all_claims"), + ruleset_q.c.job_ruleset_id, + ruleset_q.c.rules, + ) + .select_from(model.jobs) + .where(model.jobs.c.job_id == job_id) 
+ .join(ruleset_q, sqlalchemy.true(), isouter = True) + ) + async with ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + res = await conn.execute(q) + row = res.one_or_none() + if row is None: + if html: + return await render_template("error.j2", reason = f"Job ID {job_id} not found", description = "No job with this ID exists.", show_item_search = True), 404 + return {"status": 404, "message": "Job ID not found"}, 404 + ruleset = RulesetInfoPacket(row.job_ruleset_id, [db.JobRule(*rule) for rule in row.rules]) + packet = JobInfoPacket( + id = row.job_id, + type = row.type, + status = row.status, + depth = row.depth, + concurrency = row.concurrency, + nice = row.nice, + tag = row.tag, + note = row.note, + initial_page = row.initial_page, + ruleset = ruleset, + active_claims = row.all_claims or [], + ) + if html: + return await render_template("job.j2", job = packet) + v = dataclasses.asdict(packet) + v['status'] = v['status'].name + v['type'] = v['type'].name + return {"status": 200, "job": v} + +async def pages_list(q, html, job_id, volatile = True, use_status = False): + page_size = 10 + try: + offset = int(request.args.get("offset", 0)) + assert offset >= 0 + except (ValueError, AssertionError): + if html: + return await render_template("error.j2", reason = "Bad request", description = "Offset parameter was invalid."), 400 + return {"status": 400, "error": "Offset parameter was invalid."} + q = q.offset(offset).limit(page_size) + + rows = [] + async with ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + res = await conn.stream(q) + async for row in res: + if use_status: + page_id, payload, status, remaining = row + rows.append(dict(page_id = page_id, payload = payload, status = status.name, remaining = remaining)) + else: + page_id, payload, attempt_count = row + rows.append(dict(page_id = page_id, payload = payload, attempt_count = attempt_count)) + + next_offset = None + 
prev_offset = max(offset - page_size, 0) if offset > 0 else None + if len(rows) >= page_size: + next_offset = offset + page_size + if html: + return await render_template("pages.j2", rows = rows, offset = offset, next_offset = next_offset, prev_offset = prev_offset, job_id = job_id, volatile = volatile, use_status = use_status) + return {"status": 200, "rows": rows, "next": next_offset, "prev": prev_offset} + +@route_with_json("/job//pending") +async def job_pending(job_id, html): + q = ( + sqlalchemy.select(model.pages.c.page_id, model.pages.c.payload, sqlalchemy.func.count(model.attempts.c.page_id)) + .select_from(model.pages) + .join(model.attempts, model.attempts.c.page_id == model.pages.c.page_id, isouter = True) + .where(model.pages.c.job_id == job_id) + .where(model.pages_dequeue_filter) + .group_by(model.pages.c.page_id) + .order_by(*model.pages_dequeue_order) + ) + return await pages_list(q, html, job_id) + +@route_with_json("/job//claimed") +async def job_claimed(job_id, html): + q = ( + sqlalchemy.select(model.pages.c.page_id, model.pages.c.payload, sqlalchemy.func.count(model.attempts.c.page_id)) + .select_from(model.pages) + .join(model.attempts, model.attempts.c.page_id == model.pages.c.page_id, isouter = True) + .where(model.pages.c.job_id == job_id) + .where(model.pages.c.status == model.PageStatus.CLAIMED) + .group_by(model.pages.c.page_id) + .order_by(*model.pages_dequeue_order) + ) + return await pages_list(q, html, job_id) + +@route_with_json("/job//pages") +async def job_pages(job_id, html): + q = ( + sqlalchemy.select(model.pages.c.page_id, model.pages.c.payload, model.pages.c.status, model.pages.c.attempts_remaining) + .select_from(model.pages) + .where(model.pages.c.job_id == job_id) + .order_by(model.pages.c.page_id) + ) + return await pages_list(q, html, job_id, volatile = False, use_status = True) + +@route_with_json("/ruleset///test") +async def test_ruleset(job_id, ruleset_id, html): + url = request.args['url'] + async with 
ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + queue = db.Connection(conn) + latest_ruleset = (await queue.get_job_ruleset(job_id))[0] + warning = "Warning: You are not querying the latest ruleset.
" if str(latest_ruleset) != ruleset_id else "" + ruleset = await queue.get_ruleset(ruleset_id) + settings = queue.compute_page_settings(ruleset, url) + if html: + return await render_template_string( + '{{ warning|safe }} URL: {{ url }}
{% import "macros.j2" as macros %} {{ macros.build_settings(settings) }}', + settings = settings, + url = url, + warning = warning, + ) + return {"status": 200, "settings": settings} + +@app.route("/item//requisites") +async def requisites(id): + item = await QUEUE.get(id) + if not item: + return "", {"content-type": "text/plain"} + return get_requisites(item), {"content-type": "text/plain"} + +async def get_requisites(item): + async for result in QUEUE.get_results(item): + if result.type == "requisites": + for requisite in result.data: + for entry in requisite['chain']: + if req := entry['request']: + if req['url'].startswith("http"): + yield req['url'] + "\n" + yield "\nEOF" + +@app.route("/item//outlinks") +async def outlinks(id): + item = await QUEUE.get(id) + if not item: + return "", {"content-type": "text/plain"} + return get_outlinks(item), {"content-type": "text/plain"} + +async def get_outlinks(item): + async for result in QUEUE.get_results(item): + if result.type == "outlinks": + for outlink in result.data: + if outlink.startswith("http"): + yield outlink + "\n" + yield "\nEOF" + +@app.route("/screenshot//full.jpg") +@app.route("/screenshot//thumb.jpg") +@app.route("/screenshot/.jpg") +async def screenshot(id): + q = sqlalchemy.select(model.results.c.payload).where(model.results.c.result_id == id) + async with ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + res = await conn.scalar(q) + if not res: + return await render_template("error.j2", code = 404, reason = "Screenshot not found", description = "Screenshot was not found.") + return base64.b85decode(res), {"Content-Type": "image/jpeg"} + +@app.errorhandler(werkzeug.exceptions.HTTPException) +async def error(e: werkzeug.exceptions.HTTPException): + if request.accept_mimetypes.accept_json: + return await render_template("error.j2", code = e.code, reason = e.name, description = e.description), e.code + return {"status": e.code, "message": f"{e.name}: 
{e.description}"}, e.code + diff --git a/tracker/dashboard/static/styles.css b/tracker/dashboard/static/styles.css index 81ced4a..1bb6371 100644 --- a/tracker/dashboard/static/styles.css +++ b/tracker/dashboard/static/styles.css @@ -1,4 +1,4 @@ -.active { +.active, .low-importance, nav .current { color: grey; } @@ -20,8 +20,9 @@ hgroup a { text-decoration: none; } -nav .current { - color: grey; +/* Prevent there from being too much space above the breadcrumb */ +main nav[aria-label="breadcrumb"] li { + padding-top: 0 !important; } a code { @@ -49,5 +50,32 @@ section:focus-visible > :first-child { } .thumb { - max-width: 60%; + max-width: 40%; +} + +/* These are basically Pico's accordion styles but adapted so they are usable wherever. */ +.link-box { + display: block; + margin-bottom: var(--pico-spacing); + width: 100%; + text-decoration: none; +} + +.link-box-inner { + width: 100%; + text-align: left; +} + +.link-box-inner::after { + display: block; + width: 1rem; + height: calc(1rem * var(--pico-line-height,1.5)); + float: right; + transform: rotate(-90deg); + background-image: var(--pico-icon-chevron); + background-position: right center; + background-size: 1rem auto; + background-repeat: no-repeat; + content: ""; + transition: transform var(--pico-transition); } diff --git a/tracker/dashboard/templates/components/item-search.j2 b/tracker/dashboard/templates/components/item-search.j2 index ecd7edd..8fda36a 100644 --- a/tracker/dashboard/templates/components/item-search.j2 +++ b/tracker/dashboard/templates/components/item-search.j2 @@ -4,7 +4,7 @@ function validateForm() { const e = document.getElementById("search-item"); e.value = e.value.trim(); const err = document.getElementById("search-form").querySelector(".error"); - const isValid = /^[0-9A-F]{8}-[0-9A-F]{4}-[4][0-9A-F]{3}-[89AB][0-9A-F]{3}-[0-9A-F]{12}$/i.test(e.value); + const isValid = /^[0-9A-F]{8}-[0-9A-F]{4}-[47][0-9A-F]{3}-[89AB][0-9A-F]{3}-[0-9A-F]{12}$/i.test(e.value); if (isValid) { e.ariaInvalid 
= false; err.hidden = true; diff --git a/tracker/dashboard/templates/home.j2 b/tracker/dashboard/templates/home.j2 index 68efc75..db15389 100644 --- a/tracker/dashboard/templates/home.j2 +++ b/tracker/dashboard/templates/home.j2 @@ -4,23 +4,31 @@ {% block content %} -
-

Queues

-

See Pending for queued items, or Claims for in-progress items.

- {% if status %} - {% for pipeline, queues in status.items() %} -

Pipeline {{ pipeline }}:

+
+

Dashboard

+

Sorry, live update is currently not implemented. Refresh the page for current statistics.

+ {% for job_id, initial_page, note in jobs %} +
+
+

{{ initial_page }}

+

+ {% if note %} {{ note }} {% endif %} +

+
+

X pages in the queue.

+

The following pages are currently claimed:

+
    - {% for queue, value in queues.items() %} -
  • - {{ value }} {{ queue }} items. -
  • - {% endfor %}
- {% endfor %} - {% else %} -

There aren't any queued, running, or stashed items.

- {% endif %} +
+

+ {{ job_id }} +

+
+
+ {% endfor %}
{% include "components/item-search.j2" %} diff --git a/tracker/dashboard/templates/item.j2 b/tracker/dashboard/templates/item.j2 deleted file mode 100644 index ffc18e1..0000000 --- a/tracker/dashboard/templates/item.j2 +++ /dev/null @@ -1,206 +0,0 @@ -{% extends "skeleton.j2" %} - -{% block title %} Item {{ item.id }} ({{ item.item }}) {% endblock %} - -{% block content %} - -

Item {{ item.id }}

- -
-

Item data

-
    -
  • Item URL: {{ item.item }}
  • -
  • Item status: - {{ item.status }} - {% if item.stash %} - {% if item.status == "stashed" %} (in - {% else %} (from - {% endif %} stash {{ item.stash }}) - {% endif %} -
  • -
  • Queued by: {{ item.queued_by }}
  • -
  • - User agent: - {% if item.metadata.ua == "default" %} Default - {% elif item.metadata.ua == "stealth" %} Stealth - {% elif item.metadata.ua == "minimal" %} Minimal - {% elif item.metadata.ua == "googlebot" %} Googlebot Desktop - {% elif item.metadata.ua.startswith("$") %} {{ item.metadata.ua[1:] }} - {% endif %} -
  • - {% if item.metadata['custom_js'] %} -
  • -
    - Custom JavaScript: Click to show -
    {{ item.metadata['custom_js'] }}
    -
    -
  • - {% else %} -
  • Custom JavaScript: None
  • - {% endif %} -
  • Pipeline type: {{ item.pipeline_type }}
  • -
  • Priority: {{ item.priority }} (effective {{ item.priority + item.attempt_count() }})
  • -
  • Queued at: {{ item.queued_at.replace(microsecond=0) }}
  • - {% if item.claimed_at %} -
  • Last claimed at: {{ item.claimed_at.replace(microsecond=0) }}
  • - {% endif %} - {% if item.finished_at %} -
  • Finished at: {{ item.finished_at.replace(microsecond=0) }}
  • - {% endif %} -
  • Tries: {{ item.attempt_count() }}
  • - {% if error_reasons %} -
  • - Error reasons: -
      - {% for reason in item.error_reasons %} - {% endfor %} -
    -
  • - {% endif %} - {% if item.expires %} -
  • Expires: {{ item.expires.replace(microsecond=0) }}
  • - {% endif %} - {% if item.explanation %} -
  • Explanation: {{ item.explanation }}
  • - {% endif %} - {% if item.claimed_by %} -
  • Last claimed by pipeline: {{ item.claimed_by }}
  • - {% elif item.run_on %} -
  • Must run on pipeline: {{ item.run_on }}
  • - {% endif %} -
-
- -
-

Results

- {% if results %} - {% for try, try_results in results.items() %} -

- - - Attempt {{ try + 1 }} - - (on pipeline {{ item.attempts[try].pipeline }} {{ item.attempts[try].pipeline_version or "" }}) - -

-
    - {% for result in try_results %} -
  • -
    - {{ result.type }} (click to show) -
    - {% if result.type == "outlinks" %} - - {% elif result.type in ("screenshot", "cjs_screenshot") %} - - Page screenshot -
    - Click to enlarge image -
    - {% elif result.type == "requisites" %} - - - - - - - - - - {% for requisite in result.data %} - {% for pair in requisite['chain'] %} - - {% if pair.response._type == "Response" %} - - - {% elif pair.response._type == "Error" %} - - - {% elif pair.response is none %} - - - {% else %} - - - {% endif %} - {% if loop.first %} - - {% else %} - - {% endif %} - {% if pair.response.length is none %} - - {% elif pair.response.length is not defined %} - - {% else %} - - {% endif %} - - - - {% endfor %} - {% endfor %} -
    StatusTypeCategoryLengthMethodURL
    {{ pair.response.status[0] }} {{ pair.response.status[1] }}{{ pair.response.mimetype or "-" }}Error{{ pair.response.text }}NoneNo responseUnknown - this is a bug{{ pair.request.category }}Redirect ↑?n/a{{ pair.response.length }}{{ pair.request.method }} - - {{ pair.request.url }} - -
    - {% else %} - - {{ result.data }} - - {% endif %} -
    -
    -
  • - {% else %} - {% if not item.attempts[try].error %} -

    No results exist yet for this attempt. It may still be running.

    - {% endif %} - {% endfor %} - {% if item.attempts[try].error %} -
  • -
    - Error: {{ item.attempts[try].error.split('\n')[0].strip() }} {% if '\n' in reason %}(click for full message) {% endif %} -
    {{ item.attempts[try].error }}
    -
    -
  • - {% endif %} - {% if item.attempts[try].poke_reason %} -
  • -
    - Admin poke reason: {{ item.attempts[try].poke_reason.split('\n')[0].strip() }} {% if '\n' in reason %}(click for full message) {% endif %} -
    {{ item.attempts[try].poke_reason }}
    -
    -
  • - {% endif %} -
- {% endfor %} - {% else %} -

This item has not been tried yet.

- {% endif %} -
- -
-

Uploads

- -

Upload indexing is currently not implemented.

-
- - - -{% include "components/item-search.j2" %} - -{% endblock %} diff --git a/tracker/dashboard/templates/job.j2 b/tracker/dashboard/templates/job.j2 new file mode 100644 index 0000000..290ce01 --- /dev/null +++ b/tracker/dashboard/templates/job.j2 @@ -0,0 +1,144 @@ +{% extends "skeleton.j2" %} + +{% block title %} Job {{ job.id }} ({{ job.initial_page }}) {% endblock %} + +{% block content %} + +

Job {{ job.id }}

+ +{% set status = job.status.name %} +
+

Job data

+
    +
  • Initial URL: {{ job.initial_page }}
  • +
  • Job status: + + {{ status }} + +
  • +
  • Nice: {{ job.nice }}
  • +
  • Created at:
  • +
  • Max depth: {{ job.depth }}
  • +
  • Max concurrency: {{ job.concurrency }} ({{ job.active_claims|length }} active claims)
  • + {% if job.explanation %} +
  • Explanation: {{ job.explanation }}
  • + {% endif %} + {% if job.tag %} +
  • + Tag constraint: + {{ job.tag }} (list) +
  • + {% endif %} + {% if job.active_claims %} +
  • Active claims: {{ job.active_claims | map(attribute = "pipeline_id") | join(",") }}
  • + {% else %} +
  • Currently unclaimed.
  • + {% endif %} +
+
+ +
+

Settings

+

Job rules are applied in order. The last instance of a setting wins.

+

For more information about the job rules system, see the documentation.

+
    + {% for rule in job.ruleset.rules %} +
  • + {% if rule.scope == "" %} + Base rule (applies to all pages; index {{ loop.index0 }}): + {% else %} + All pages matching {{ rule.scope }} (index {{ loop.index0 }}): + {% endif %} + {{ macros.build_settings(rule.payload) }} +
  • + {% endfor %} +
+

(Ruleset ID: {{ job.ruleset.id }})

+
+

Rule tester

+

+ You can test how a particular URL will be handled with this tool. + The job rules system uses the Python regex module with the default settings. + For accurate results, ensure you use the full URL (including e.g. protocol). +

+

+ Currently-loaded ruleset ID: {{ job.ruleset.id }} +

+
+
+ + + +
+
+
+ +
+
+ +
+

Pages

+
+

Claimed Pages

+ +
+
+

Pending Pages

+ +
+
+

All Pages

+ List all pages (in order of creation) +
+
+ +{# +
+

Uploads

+ +

Upload indexing is currently not implemented.

+
+#} + + + +{% include "components/item-search.j2" %} + +{% endblock %} diff --git a/tracker/dashboard/templates/macros.j2 b/tracker/dashboard/templates/macros.j2 new file mode 100644 index 0000000..0ae1cb1 --- /dev/null +++ b/tracker/dashboard/templates/macros.j2 @@ -0,0 +1,60 @@ +{% macro build_ua(ua) %} + {% if ua == "default" %} Default + {% elif ua == "stealth" %} Stealth + {% elif ua == "minimal" %} Minimal + {% elif ua == "googlebot" %} Googlebot Desktop + {% elif ua.startswith("$") %} {{ ua[1:] }} + {% endif %} +{% endmacro %} + +{% macro build_settings(settings) %} +
    + {% if settings.ua is defined %} +
  • User agent: {{ build_ua(settings.ua) }}
  • + {% endif %} + {% if settings.custom_js is defined %} + {% if settings.custom_js %} +
  • +
    + Custom JavaScript: Click to show +
    {{ settings.custom_js }}
    +
    +
  • + {% else %} +
  • Custom JavaScript: None
  • + {% endif %} + {% endif %} + {% if settings.skip is defined %} + {% if settings.skip %} +
  • Skip this page
  • + {% else %} +
  • Do not skip this page
  • + {% endif %} + {% endif %} + {% if settings.accept is defined %} +
  • + {% if settings.accept %} + Accept this URL into the queue + {% else %} + Reject this URL from the queue + {% endif %} +
  • + {% endif %} +
+{% endmacro %} + +{% macro page_status(status, attempts_remaining) %} + {% if status == "DRAINING" %} + + DRAINING + + {% elif status == "READY" %} + {% if not attempts_remaining %} + COMPLETED + {% else %} + TODO + {% endif %} + {% else %} + {{ status }} + {% endif %} +{% endmacro %} diff --git a/tracker/dashboard/templates/page.j2 b/tracker/dashboard/templates/page.j2 new file mode 100644 index 0000000..84fb0f1 --- /dev/null +++ b/tracker/dashboard/templates/page.j2 @@ -0,0 +1,185 @@ +{% extends "skeleton.j2" %} + +{% block title %} Page {{ page.id }} ({{ page.url }}) {% endblock %} + +{% block content %} +{% set status = page.status.name %} + + + +

{{ page.url }}

+ +
+
    +
  • Page status: {{ macros.page_status(page.status.name, page.attempts_remaining) }}
  • +
  • Nice: {{ page.nice }} (effective {{ page.nice + page.attempts }})
  • +
  • Tries remaining: {{ page.attempts_remaining }} ({{ page.attempts }} attempts since last reset)
  • +
+
+ +
+

Results

+ {% if page.all_attempts %} + {% for attempt in page.all_attempts if attempt.finished %} +

+ + + Attempt {{ attempt.id }} + + (on pipeline {{ attempt.pipeline }} {{ attempt.pipeline_version }}) + +

+
    +
  • + Applied settings: + {{ macros.build_settings(attempt.applied_settings) }} +
  • + {% if attempt.error %} +
  • +
    + + Error: + {{ attempt.error.split('\n')[0].strip() }} + +
    {{ attempt.error }}
    +
    +
  • + {% endif %} +
  • + Status code: {{ attempt.results.status_code }} +
  • +
  • + Final URL: {{ attempt.results.final_url }} +
  • +
  • + Screenshot: +
    + {% if attempt.results.screenshot %} + + Page screenshot +
    + Click to enlarge image +
    + {% else %} + No screenshot was recorded. + {% endif %} +
  • + {% if attempt.results.cjs_screenshot %} +
  • + Custom JS screenshot: +
    + + Page screenshot +
    + Click to enlarge image +
    +
  • + {% endif %} + +
  • +
    + Requisites: +
    + + + + + + + + + + {% for requisite in attempt.results.requisites %} + {% for pair in requisite['chain'] %} + + {% if pair.response._type == "Response" %} + + + {% elif pair.response._type == "Error" %} + + + {% elif pair.response is none %} + + + {% else %} + + + {% endif %} + {% if loop.first %} + + {% else %} + + {% endif %} + {% if pair.response.length is none %} + + {% elif pair.response.length is not defined %} + + {% else %} + + {% endif %} + + + + {% endfor %} + {% endfor %} +
    StatusTypeCategoryLengthMethodURL
    {{ pair.response.status[0] }} {{ pair.response.status[1] }}{{ pair.response.mimetype or "-" }}Error{{ pair.response.text }}NoneNo responseUnknown - this is a bug{{ pair.request.category }}Redirect ↑?n/a{{ pair.response.length }}{{ pair.request.method }} + + {{ pair.request.url }} + +
    +
    +
    +
  • +
+ {% endfor %} + {% for attempt in page.all_attempts if not attempt.finished %} +

Attempt {{ attempt.id }} has not been completed.

+ {% endfor %} + {% else %} +

This item has not been tried yet.

+ {% endif %} +
+ +{# +
+

Uploads

+ +

Upload indexing is currently not implemented.

+
+#} + + + +{% include "components/item-search.j2" %} + +{% endblock %} diff --git a/tracker/dashboard/templates/pages.j2 b/tracker/dashboard/templates/pages.j2 new file mode 100644 index 0000000..b86feb3 --- /dev/null +++ b/tracker/dashboard/templates/pages.j2 @@ -0,0 +1,57 @@ +{% extends "skeleton.j2" %} + +{% block title %} mnbot pages iframe {% endblock %} + +{% block override_header %} {% endblock %} + +{% block content %} + + + +
+
+ Reload +
+ +
+ +
+ {% for row in rows %} + + + + {% endfor %} + {% if not offset and not rows %} +

There aren't any pages that match this query.

+ {% endif %} +
+ +{% endblock %} diff --git a/tracker/dashboard/templates/skeleton.j2 b/tracker/dashboard/templates/skeleton.j2 index 4d6e414..aa74263 100644 --- a/tracker/dashboard/templates/skeleton.j2 +++ b/tracker/dashboard/templates/skeleton.j2 @@ -1,10 +1,11 @@ + {% import "macros.j2" as macros %} - {% block title %} why do they call it oven when you of in the cold food of out hot eat the food {% endblock %} | mnbot + {% block title %} Define the title block! {% endblock %} | mnbot + {% block override_header %}

mnbot

@@ -50,6 +84,7 @@
+ {% endblock %}
{% block content %}

Congrats! You broke the template!

diff --git a/tracker/irc/__main__.py b/tracker/irc/__main__.py new file mode 100644 index 0000000..321b47f --- /dev/null +++ b/tracker/irc/__main__.py @@ -0,0 +1,380 @@ +import os +import asyncio + +import aiohttp +import datetime +import validators +import traceback + +import sqlalchemy + +from bot2h import Bot, Colour, Format, User + +from ..common import db, model + +H2IBOT_GET_URL = os.environ['H2IBOT_GET_URL'] +H2IBOT_POST_URL = os.environ['H2IBOT_POST_URL'] +TRACKER_BASE_URL = os.environ['TRACKER_BASE_URL'].rstrip("/") +DOCUMENTATION_URL = os.environ['DOCUMENTATION_URL'] +MNBOT_HEADER = "//! mnbot v1" # DO NOT ADD \n or \r\n here. + +def is_mnbot_js(payload): + """ + Returns true if the payload is a mnbot CJS header. + """ + if not payload.startswith(MNBOT_HEADER): return False + def validate_suffix(payload, suffix): + return payload[len(MNBOT_HEADER):len(MNBOT_HEADER)+len(suffix)] == suffix + return validate_suffix(payload, "\n") or validate_suffix(payload, "\r\n") + +def item_url(id): + if not id: + return "" + return f"{TRACKER_BASE_URL}/job/{id}" + +bot = Bot(H2IBOT_GET_URL, H2IBOT_POST_URL, max_coros = 1) + +PRESET_USER_AGENTS = { + "curl": "curl/7.88.1", + "archivebot": ( + "ArchiveTeam ArchiveBot/20240923.203d40a (wpull 2.0.3) and not Mozilla/5.0 " + "(Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) " + "Chrome/42.0.2311.90 Safari/537.36" + ), + "googlebot1": ( + "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)" + ), +} + +class ValidationError(Exception): + def __init__(self, url): + self.url = url + +@bot.add_argument("--concurrency", "-c", type = int, default = 1) +@bot.add_argument( + "--user-agent", "-u", + choices = ("default", "stealth", "curl", "archivebot", "minimal", "googlebot1", "googlebot"), + default = "default" +) +@bot.add_argument("--explanation", "--explain", "-e") +@bot.add_argument("--custom-js") +@bot.add_argument("--skip-url-validation", action = "store_true") +@bot.add_argument("--nice", 
"-n", type = int, default = 0) +@bot.add_argument("url") +@bot.add_argument("redirector", nargs = "?", metavar = "<") +@bot.argparse("!brozzle") +@bot.command({"!b", "!brozzle"}, required_modes="+@") +async def brozzle(self: Bot, user: User, ran, args): + metadata = {} + if args.nice < -10: + if "@" not in user.modes: + yield "Sorry, but only operators can queue with a niceness lower than -10." + return + if args.skip_url_validation: + if "@" not in user.modes: + yield "Sorry, but only operators can bypass URL validation." + return + + initial_settings = db.PageSettings() + + if args.custom_js: + if "@" not in user.modes: + yield "Sorry, but only operators can use custom JavaScript." + return + try: + async with AIOHTTP_SESSION.get(args.custom_js) as resp: + if resp.status != 200: + yield f"Failed to retrieve custom JS! Got status {resp.status} (expected 200)." + return + custom_js = await resp.text() + if not is_mnbot_js(custom_js): + yield "Error: Custom JS must start with a valid mnbot header." + return + initial_settings.custom_js = custom_js + except Exception as e: + yield f"Failed to retrieve custom JS ({type(e)} was raised)." 
+ print("Custom JS exception:") + traceback.print_exc() + return + if ua := PRESET_USER_AGENTS.get(args.user_agent): + initial_settings.ua = "$" + ua + elif args.user_agent != "default": + initial_settings.ua = args.user_agent + + job_id = db.generate_id() + job = db.JobCreation( + job_id = job_id, + type = model.JobType.BROZZLER, + created_by = user.nick, + metadata = metadata, + initial_page = args.url, + concurrency = args.concurrency, + nice = args.nice, + note = args.explanation, + ) + + async with ENGINE.connect() as conn: + queue = db.Connection(conn) + + await queue.create_jobs([job]) + await queue.create_job_rule(job_id, None, db.JobRule("", initial_settings.as_dict())) + + num_urls = 0 + async def submit_page_batch(pages): + nonlocal num_urls + data = [] + for page in pages: + if not args.skip_url_validation: + result = validators.url(page, strict_query = False, private = False) + if result is not True: + raise ValidationError(page) + data.append(db.PageCreation(page_id = db.generate_id(), payload = page, parent_page = None)) + res = await queue.create_pages(job_id, data) + # create_pages returns a mapping of {given_id: actual_id}. + # If given_id != actual_id, the page was already in the database. + # (In this case, that means there was duplication in the list.) + num_urls += sum(1 for key, value in res.items() if key == value) + + if args.redirector == "<": + try: + async with AIOHTTP_SESSION.get(args.url) as resp: + if resp.status != 200: + yield f"Failed to retrieve URL list! Got status {resp.status} (expected 200)." + return + buf = [] + async for url in resp.content: + url = url.decode().rstrip("\r\n") + if not url: + continue + buf.append(url) + if len(buf) >= 50: + await submit_page_batch(buf) + buf = [] + if buf: + await submit_page_batch(buf) + buf = [] + except ValidationError as e: + raise + except Exception as e: + yield f"Failed to retrieve the URL list ({type(e)} was raised)." 
+ print("Retrieval exception:") + traceback.print_exc() + return + if num_urls == 0: + yield "Your list appears to be empty." + return + await conn.commit() + yield f"Queued {num_urls} pages from {args.url} for Brozzler-based archival. You will be notified when it finishes. Use !status {job_id} or check {item_url(job_id)} for details." + else: + if args.redirector: + yield "Sorry, but only one URL or URL list can be specified at a time." + return + await submit_page_batch([args.url]) + await conn.commit() + yield f"Queued {args.url} for Brozzler-based archival. You will be notified when it finishes. Use !status {job_id} or check {item_url(job_id)} for details." + +@bot.add_argument("--skip", action = "store_true", dest = "skip", default = None) +@bot.add_argument("--no-skip", action = "store_false", dest = "skip", default = None) +@bot.add_argument("--accept", action = "store_true", dest = "accept", default = None) +@bot.add_argument("--reject", action = "store_false", dest = "accept", default = None) +@bot.add_argument("--add-before", default = None, type = int) +@bot.add_argument("--ensure-ruleset", default = None) +@bot.add_argument("pattern") +@bot.add_argument("job_id") +@bot.argparse("!addrule") +@bot.command({"!addrule"}) +async def addrule(self: Bot, user: User, ran, args): + settings = {} + if args.skip is not None: + settings['skip'] = args.skip + if args.accept is not None: + settings['accept'] = args.accept + if not settings: + yield "You must specify at least one setting." + return + # Ensure the regex can be compiled + db.regex.compile(args.pattern) + + async with ENGINE.begin() as conn: + queue = db.Connection(conn) + nid, nidx = await queue.create_job_rule(args.job_id, args.add_before, db.JobRule(args.pattern, settings), args.ensure_ruleset) + yield f"Created new rule at index {nidx} (new ruleset ID: {nid})." 
@bot.command({"!concurrency", "!con"})
async def concurrency(self: Bot, user: User, ran, job_id, num):
    """Handle ``!concurrency <job_id> <num>``: set a job's maximum concurrency.

    Yields exactly one reply: a confirmation, a "not found" notice when no
    job row matched, or an error when ``num`` is not a positive integer.
    An unparsable ``job_id`` propagates whatever ``db.parse_id`` raises.
    """
    job_id = db.parse_id(job_id)
    # Validate explicitly rather than with ``assert``: assertions are removed
    # when Python runs with -O, which would let zero/negative values through.
    try:
        num = int(num)
    except ValueError:
        num = 0
    if num <= 0:
        yield "Sorry, but concurrency must be a positive integer."
        return
    async with ENGINE.begin() as conn:
        q = sqlalchemy.update(model.jobs).where(model.jobs.c.job_id == job_id).values(concurrency = num)
        res = await conn.execute(q)
        updated = bool(res.rowcount)
    # Reply only after the transaction block has exited, so a failed commit
    # surfaces as an error instead of a false confirmation.
    if updated:
        yield f"Updated concurrency of {job_id} to {num}."
    else:
        yield f"No job with ID {repr(str(job_id))} could be found."


async def generate_status_message(job: str, queue: db.Connection):
    """Build the one-line ``!status`` reply for a single job.

    ``job`` is the job ID exactly as the user typed it; ``queue`` wraps the
    connection the lookup runs on. Returns the message text, or a "not found"
    message when no row has that ID. An unparsable ID propagates whatever
    ``db.parse_id_ex`` / ``db.parse_id`` raise.
    """
    # The queue time is derived from the ID itself. The parsed timestamp is
    # divided by 1000 before use, i.e. it is treated as milliseconds.
    ts = db.parse_id_ex(job).timestamp / 1000
    date = datetime.datetime.fromtimestamp(ts, datetime.UTC)
    # Restrict the query to the requested job. Without this filter the first
    # row of the whole jobs table would be reported for every ID.
    q = (
        sqlalchemy.select(model.jobs.c.status, model.jobs.c.initial_page, model.jobs.c.note)
        .where(model.jobs.c.job_id == db.parse_id(job))
    )
    res = await queue.conn.execute(q)
    ent = res.first()
    if not ent:
        return f"No job with ID {repr(job)} could be found."
    return f"Job {job} ({repr(ent[1])}) has status {ent[0].name} and was queued at {date.isoformat(timespec='seconds')}. See {item_url(job)} for more information. Explanation: {ent[2]}"


@bot.command("!status")
async def status(self: Bot, user: User, ran, *jobs):
    """Handle ``!status [job_id ...]``.

    With job IDs, yields one status line per job. With no arguments, yields
    one line per pipeline type with its active job count, preceded by a
    notice when there are no counts at all.
    """
    async with ENGINE.connect() as conn:
        conn = await conn.execution_options(postgresql_readonly = True)
        queue = db.Connection(conn)
        if jobs:
            for job in jobs:
                yield await generate_status_message(job, queue)
        else:
            # ``queue`` is already bound to the read-only connection above,
            # so no additional connection is needed for the counts query.
            counts = await queue.get_job_counts()
            if not counts:
                yield "There aren't any queued, running, or stashed items."
            for pipeline_type, count in counts.items():
                yield f"Status for {str(pipeline_type.name)}: {count} active jobs."
@bot.command({"!explain", "!e"})
async def explain(self: Bot, user: User, ran, id, *reason):
    """Handle ``!explain <job_id> [reason ...]``: set or clear a job's note.

    The reason words are joined with single spaces; an empty reason stores
    ``None``. Yields one reply saying whether any row was updated.
    """
    id = db.parse_id(id)
    async with ENGINE.begin() as conn:
        note = " ".join(reason) or None
        stmt = (
            sqlalchemy.update(model.jobs)
            .where(model.jobs.c.job_id == id)
            .values(note = note)
        )
        outcome = await conn.execute(stmt)
        if not outcome.rowcount:
            yield "No item was found."
        else:
            yield f"Reason for {id} set to {note!r}."


@bot.command("!tag")
async def tag(self: Bot, user: User, ran, command: str, pipeline_id: str, tag = None):
    """Handle ``!tag <list|add|remove> <pipeline_id> [tag]``.

    ``list`` takes no tag and runs on a read-only connection. Every other
    subcommand requires a tag and operator mode, and commits after the
    change. Yields exactly one reply per invocation.
    """
    async with ENGINE.connect() as conn:
        # Argument and permission checks come first; none of them touch the
        # database.
        if command == "list":
            if tag:
                yield "Too many arguments for !tag list."
                return
            conn = await conn.execution_options(postgresql_readonly = True)
        elif not tag:
            yield "A tag must be provided."
            return
        elif "@" not in user.modes:
            yield "Sorry, but only operators can modify tags."
            return

        # Both the listing and the modifying paths resolve the pipeline the
        # same way.
        queue = db.Connection(conn)
        try:
            pipeline = await queue.pipeline(pipeline_id)
        except db.NoSuchPipelineError:
            yield f"Pipeline {pipeline_id} does not exist."
            return

        if command == "list":
            tags = await pipeline.get_tags()
            if not tags:
                yield f"Pipeline {pipeline_id} has no tags."
            else:
                yield f"Pipeline {pipeline_id} has the following tags: " + ", ".join(tags)
        elif command == "add":
            await pipeline.create_tags(tag)
            await conn.commit()
            yield f"Added {tag} to pipeline {pipeline_id}."
        elif command == "remove":
            removed = await pipeline.remove_tags(tag)
            await conn.commit()
            if not removed:
                yield f"Tag {tag} was not found on pipeline {pipeline_id}, no action was taken."
            else:
                yield f"Removed {tag} from pipeline {pipeline_id}."
        else:
            yield "Invalid subcommand."


@bot.command("!help")
async def help(self: Bot, user: User, ran, command = None):
    """Handle ``!help``: point the user at the documentation URL.

    The optional ``command`` argument is accepted but not used.
    """
    yield f"Documentation can be found at {DOCUMENTATION_URL}."
+ +@bot.command("!page") +async def page(self: Bot, user: User, ran, page_id, action, arg = None): + page_id = db.parse_id(page_id) + async with ENGINE.connect() as conn: + queue = db.Connection(conn) + if arg: + if "@" not in user.modes: + yield "Sorry, but only operators can update page metadata." + return + if action == "status": + ns = model.PageStatus[arg.upper()] + q = sqlalchemy.update(model.pages).where(model.pages.c.page_id == page_id).values(status = ns) + await conn.execute(q) + await conn.commit() + yield f"Updated {page_id} to status {ns.name}." + elif action == "tries": + try: + nt = int(arg) + except ValueError: + yield f"Invalid integer {arg}." + return + await queue.retry_page(page_id, nt) + await conn.commit() + yield f"Cleared attempt counter and set maximum of {nt} tries for {page_id}." + return + else: + yield f"{action} is not a valid query." + return + else: + conn = await conn.execution_options(postgres_readonly = True) + if action == "status": + q = sqlalchemy.select(model.pages.c.status).where(model.pages.c.page_id == page_id) + res = await conn.scalar(q) + if not res: + yield f"Page {page_id} does not exist." + return + yield f"Page {page_id} has status {res.name}." + return + elif action == "tries": + q = sqlalchemy.select(model.pages.c.attempts, model.pages.c.attempts_remaining).where(model.pages.c.page_id == page_id) + result = await conn.execute(q) + res = result.first() + if not res: + yield f"Page {page_id} does not exist." + return + yield f"Since last reset, page {page_id} has been tried {res[0]} out of {res[0] + res[1]} allowed attempts." + return + else: + yield f"{action} is not a valid query." + return + +RED = Colour.make_colour(Colour.RED) +@bot.exception_handler +async def handler(self: Bot, command, user: User, e): + if isinstance(e, db.InvalidIdError): + return f"{user.nick}: {RED}Invalid UUID." 
+ elif isinstance(e, ValidationError): + return f"{user.nick}: {RED}Failed to validate URL {repr(e.url)}, cowardly bailing out.{Format.RESET} (Ops may use --skip-url-validation to bypass this.)" + elif isinstance(e, db.SerializationFailure): + return f"{user.nick}: {RED}Serialization failure! Please try again." + else: + print("Exception occurred!") + traceback.print_exc() + return f"{user.nick}: {RED}An error occurred while processing the command." + +async def main(): + global ENGINE, AIOHTTP_SESSION + ENGINE = await db.create_engine() + AIOHTTP_SESSION = aiohttp.ClientSession() + await bot.run_forever() + +if __name__ == "__main__": + asyncio.run(main()) + diff --git a/tracker/irc/app.py b/tracker/irc/app.py deleted file mode 100644 index d0dd087..0000000 --- a/tracker/irc/app.py +++ /dev/null @@ -1,196 +0,0 @@ -import os -import asyncio - -import aiohttp -import validators - -from bot2h import Bot, Colour, Format, User -from rue import Queue, Status - -H2IBOT_GET_URL = os.environ['H2IBOT_GET_URL'] -H2IBOT_POST_URL = os.environ['H2IBOT_POST_URL'] -TRACKER_BASE_URL = os.environ['TRACKER_BASE_URL'].rstrip("/") -DOCUMENTATION_URL = os.environ['DOCUMENTATION_URL'] -MNBOT_HEADER = "//! mnbot v1" # DO NOT ADD \n or \r\n here. - -def is_mnbot_js(payload): - """ - Returns true if the payload is a mnbot CJS header. 
- """ - if not payload.startswith(MNBOT_HEADER): return False - def validate_suffix(payload, suffix): - return payload[len(MNBOT_HEADER):len(MNBOT_HEADER)+len(suffix)] == suffix - return validate_suffix(payload, "\n") or validate_suffix(payload, "\r\n") - -def item_url(id: str | None): - if not id: - return "" - return f"{TRACKER_BASE_URL}/item/{id}" - -bot = Bot(H2IBOT_GET_URL, H2IBOT_POST_URL, max_coros = 1) - -QUEUE = Queue("mnbot") -asyncio.run(QUEUE.check()) -print("setup complete") - -AIOHTTP_SESSION = None - -PRESET_USER_AGENTS = { - "curl": "curl/7.88.1", - "archivebot": ( - "ArchiveTeam ArchiveBot/20240923.203d40a (wpull 2.0.3) and not Mozilla/5.0 " - "(Windows NT 6.1; WOW64) AppleWebKit/537.36 (KHTML, like Gecko) " - "Chrome/42.0.2311.90 Safari/537.36" - ), - "googlebot1": ( - "Mozilla/5.0 (compatible; Googlebot/2.1; +http://www.google.com/bot.html)" - ), -} - -@bot.add_argument( - "--user-agent", "-u", - choices = ("default", "stealth", "curl", "archivebot", "minimal", "googlebot1", "googlebot"), - default = "default" -) -@bot.add_argument("--explanation", "--explain", "-e") -@bot.add_argument("--custom-js") -@bot.add_argument("--skip-url-validation", action = "store_true") -@bot.add_argument("--nice", "-n", type = int, default = 0) -@bot.add_argument("url") -@bot.argparse("!brozzle") -@bot.command({"!b", "!brozzle"}, required_modes="+@") -async def brozzle(self: Bot, user: User, ran, args): - global AIOHTTP_SESSION - custom_js = None - if args.nice < -10: - if "@" not in user.modes: - yield "Sorry, but only operators can queue with a niceness lower than -10." - return - if args.custom_js: - if "@" not in user.modes: - yield "Sorry, but only operators can use custom JavaScript." - return - try: - if not AIOHTTP_SESSION: - AIOHTTP_SESSION = aiohttp.ClientSession() - async with AIOHTTP_SESSION.get(args.custom_js) as resp: - if resp.status != 200: - yield f"Error: Got status {resp.status} (expected 200) when fetching custom JS." 
- return - custom_js = await resp.text() - if not is_mnbot_js(custom_js): - yield "Error: Custom JS must start with a valid mnbot header." - return - except Exception: - yield "An error occured when retrieving custom JS." - return - if args.skip_url_validation: - if "@" not in user.modes: - yield "Sorry, but only operators can bypass URL validation." - return - else: - result = validators.url(args.url, strict_query = False, private = False) - if result is not True: - yield "Failed to validate your URL, cowardly bailing out. (Ops may use --skip-url-validation to bypass this.)" - return - ua = PRESET_USER_AGENTS.get(args.user_agent, args.user_agent) - if ua := PRESET_USER_AGENTS.get(args.user_agent): - ua = "$" + ua - else: - ua = args.user_agent - ent = await QUEUE.new( - args.url, - "brozzler", - user.nick, - explanation = args.explanation, - metadata = {"ua": ua, "custom_js": custom_js}, - priority = args.nice - ) - yield f"Queued {args.url} for Brozzler-based archival. You will be notified when it finishes. Use !status {ent.id} or check {item_url(ent.id)} for details." - -async def generate_status_message(job: str): - ent = await QUEUE.get(job) - if not ent: - return f"No job with ID {repr(job)} could be found. Note: Items cannot currently be looked up by URL." - return f"Job {ent.id} ({repr(ent.item)}) has status {ent.status.upper()} and was queued at {ent.queued_at.isoformat(timespec='seconds')}. See {item_url(ent.id)} for more information. Explanation: {ent.explanation}" - -@bot.command("!status") -async def status(self: Bot, user: User, ran, *jobs): - if jobs: - for job in jobs: - yield await generate_status_message(job) - else: - c = await QUEUE.counts() - if not c.counts: - yield "There aren't any queued, running, or stashed items." - for pipeline_type, counts in c.counts.items(): - yield f"Status for {pipeline_type.upper()}: {counts['todo']} pending items, {counts['claimed']} items in progress, {counts['stashed']} items stashed away." 
- if limbo := c.limbo: - yield f"{limbo} items are possibly in limbo and may need admin intervention." - -@bot.command("!limbo") -async def limbo(self: Bot, user: User, ran): - jobs = await QUEUE.get_limbo() - yield f"Jobs in limbo: {repr(jobs)}" - -@bot.command({"!w", "!whereis"}) -async def whereis(self: Bot, user: User, ran, id: str): - job = await QUEUE.get(id) - if not job: - yield f"Job {id} does not exist." - elif job.status != Status.CLAIMED: - yield f"Job {job.id} is not currently claimed." - else: - yield f"Job {job.id} is currently claimed by pipeline {job.pipeline_type}." - -@bot.command({"!explain", "!e"}) -async def explain(self: Bot, user: User, ran, id: str, *reason): - job = await QUEUE.get(id) - if not job: - yield f"Job {id} does not exist." - else: - r = " ".join(reason) - nj = await QUEUE.change_explanation(job, r) - yield f"Reason for {job.id} set to {nj.explanation!r}." - -@bot.command("!!reclaim", required_modes="@") -async def abandon(self: Bot, user: User, ran, id: str, *reason): - item = await QUEUE.get(id) - if not item: - yield f"Job {id} does not exist." - return - if item.status != Status.CLAIMED: - yield f"Job {id} is not claimed." - return - if not reason: - yield "You must provide a reason." - return - reason = f"Manual reclaim by {user.nick}: {' '.join(reason)}" - yield f"Explanation: {reason!r}" - new_item = await QUEUE.fail(item, reason, item.current_attempt(), is_poke = True) - if new_item.status == Status.ERROR: - yield f"Max tries reached for {new_item.id}, it has been moved to ERROR." - elif new_item.status == Status.TODO: - yield f"{new_item.id} has been moved to todo. Tries: {len(new_item.attempts)}" - else: - yield f"Unexpected status {new_item.status} for item {new_item.id}. This is probably bad." 
- -@bot.command("!!dripfeed", required_modes = "@") -async def dripfeed(self: Bot, user: User, ran: str, stash: str, raw_concurrency: str): - try: - concurrency: int = int(raw_concurrency) - except ValueError: - yield "Invalid integer." - return - await QUEUE.set_dripfeed_behaviour(stash, concurrency) - if concurrency > 0: - yield f"Enabled dripfeeding of {stash} with a rate of {concurrency} concurrent." - else: - yield f"Disabled dripfeeding of {stash}." - -@bot.command("!help") -async def help(self: Bot, user: User, ran, command = None): - yield f"Documentation can be found at {DOCUMENTATION_URL}." - -asyncio.run(bot.run_forever()) - diff --git a/tracker/scripts/__main__.py b/tracker/scripts/__main__.py new file mode 100644 index 0000000..09a1080 --- /dev/null +++ b/tracker/scripts/__main__.py @@ -0,0 +1,48 @@ +import argparse +import asyncio +import sys + +from ..common import model +from ..common import db + +import secrets +import string + +parser = argparse.ArgumentParser() +parser.add_argument("--uri", default = None, help = "URI to a postgres database; default is the value of MNBOT_DATABASE_URI") + +subparsers = parser.add_subparsers(required = True) + +async def create(args): + if args.tries > 32767 or args.tries < 0: + raise ValueError("Invalid value of tries, must fit within the positive range of a SMALLINT") + engine = await db.create_engine(args.uri, check_version = False) + async with engine.connect() as conn: + await conn.run_sync(model.metadata_obj.create_all) + await conn.execute(model.options.insert().values(key = "version", value = str(model.SCHEMA_VERSION))) + await conn.execute(model.options.insert().values(key = "tries", value = str(args.tries))) + await conn.commit() + await engine.dispose() + +subparser_create = subparsers.add_parser("create") +subparser_create.add_argument("--tries", type = int) +subparser_create.set_defaults(func = create) + +async def add_pipeline(args): + engine = await db.create_engine(args.uri, check_version = True) + 
async with engine.begin() as conn: + queue = db.Connection(conn) + alphabet = string.ascii_letters + string.digits + password = "".join(secrets.choice(alphabet) for i in range(32)) + await queue.create_pipeline(args.id, args.matchonly, password) + print("Created", "matchonly" if args.matchonly else "regular", "pipeline", args.id, "with password", password) + await engine.dispose() + +subparser_add_pipeline = subparsers.add_parser("add_pipeline") +subparser_add_pipeline.set_defaults(func = add_pipeline) +subparser_add_pipeline.add_argument("--matchonly", action = "store_true", default = False) +subparser_add_pipeline.add_argument("id") + +if __name__ == "__main__": + args = parser.parse_args() + asyncio.run(args.func(args)) diff --git a/tracker/server/__main__.py b/tracker/server/__main__.py new file mode 100644 index 0000000..619adab --- /dev/null +++ b/tracker/server/__main__.py @@ -0,0 +1,203 @@ +import asyncio +import base64 +import collections +import time +import os +import json +import typing +import logging + +from ..common import model, db + +from websockets.asyncio.server import ServerConnection, basic_auth, serve +from bot2h import Format, SendOnlyBot, Colour + +import sqlalchemy + +logging.basicConfig(level=logging.INFO) + +INFO_URL = os.environ['INFO_URL'] +TRACKER_BASE_URL = os.environ['TRACKER_BASE_URL'].rstrip("/") + +def item_url(id: model.UUID): + return f"{TRACKER_BASE_URL}/job/{id}" + +bot = SendOnlyBot(os.environ['H2IBOT_POST_URL']) + +def notify_user(job_id: model.UUID, initial_item: str, author: str, message: str): + url = item_url(job_id) + return f"{author}: Your job {job_id} for {initial_item} {message} See {url} for more information." 
+ +HANDLERS = {} + +def handler(msg_type: str): + def decorator(f: typing.Callable): + assert msg_type not in HANDLERS + HANDLERS[msg_type] = f + return f + return decorator + +Response: typing.TypeAlias = tuple[int, typing.Optional[dict[str, typing.Any]]] +HandlerContext = collections.namedtuple("HandlerContext", ["message", "pipeline", "version"]) + +@handler("System:ping") +async def pong(ctx: HandlerContext) -> Response: + return 204, None + +@handler("Item:claim") +async def get(ctx: HandlerContext, *, slot) -> Response: + pipeline: db.Pipeline = ctx.pipeline + notify_message = None + + async with pipeline.parent.conn.begin(): + try: + claim = await pipeline.find_claim_page(ctx.version, slot, model.JobType.BROZZLER) + except db.JobExhausted as e: + claim = None + ns = await pipeline.parent.update_job_status(e.job_id) + q = sqlalchemy.select(model.jobs.c.initial_page, model.jobs.c.created_by).where(model.jobs.c.job_id == e.job_id) + row = (await pipeline.conn.execute(q)).one() + if ns == model.JobStatus.DONE: + notify_message = notify_user(e.job_id, row[0], row[1], "has finished.") + + if notify_message: + await bot.send_message(notify_message) + + if claim: + payload = { + "item": claim.as_json_friendly_dict(), + "info_url": INFO_URL + } + else: + payload = { + "item": None, + "message": "No items found." + } + return 200, payload + +RED = Colour.make_colour(Colour.RED, escape = False) +RESET = Format.RESET +MONO = Format.MONOSPACE +@handler("Item:fail") +async def fail(ctx: HandlerContext, *, attempt_id, message, fatal) -> Response: + attempt_id = db.parse_id(attempt_id) + pipeline: db.Pipeline = ctx.pipeline + notify_message = None + + # Must commit before sending any notification message, otherwise + # serialization failures may cause excess or incorrect notifications. 
+ async with pipeline.parent.conn.begin(): + tries_remaining = await pipeline.fail_attempt(attempt_id, message, fatal) + + if tries_remaining <= 0: + job_id, (initial_item, author) = await pipeline.parent.attempt_id_to_job_id(attempt_id, [model.jobs.c.initial_page, model.jobs.c.created_by]) + new_status = await pipeline.parent.update_job_status(job_id) + if new_status == model.JobStatus.DONE: + if await pipeline.parent.is_single_job(job_id): + summary = message.split("\n", 1)[0].strip() + notify_message = notify_user(job_id, initial_item, author, f"has {RED}failed{RESET} (last error: {MONO}{summary}{RESET}).") + else: + notify_message = notify_user(job_id, initial_item, author, "has finished.") + + if notify_message: + await bot.send_message(notify_message) + return 204, None + +@handler("Item:store") +async def store(ctx: HandlerContext, *, result_id, attempt_id, type, payload): + result_id = db.parse_id(result_id) + attempt_id = db.parse_id(attempt_id) + pipeline: db.Pipeline = ctx.pipeline + if type == "cjs_screenshot": + type = "custom_js_screenshot" + await pipeline.create_result(attempt_id, result_id, model.ResultType[type.upper()], payload) + return 201, {"new_id": str(result_id)} + +@handler("Item:finish") +async def finish(ctx: HandlerContext, *, attempt_id) -> Response: + attempt_id = db.parse_id(attempt_id) + pipeline: db.Pipeline = ctx.pipeline + notify_message = None + + async with pipeline.parent.conn.begin(): + await pipeline.finish_attempt(attempt_id) + job_id, (initial_item, author) = await pipeline.parent.attempt_id_to_job_id(attempt_id, [model.jobs.c.initial_page, model.jobs.c.created_by]) + new_status = await pipeline.parent.update_job_status(job_id) + if new_status == model.JobStatus.DONE: + notify_message = notify_user(job_id, initial_item, author, "has finished.") + + if notify_message: + await bot.send_message(notify_message) + return 204, None + +async def handle_connection(websocket: ServerConnection): + initial = json.loads(await 
websocket.recv()) + version = initial['v'] + protocol = initial.get("p", 1) + if protocol != 2: + await websocket.close(reason = "Protocol mismatch! Please update your client") + slots = list(range(0, initial['num_slots'])) + async with ENGINE.connect() as conn: + queue = db.Connection(conn) + pipeline = await queue.pipeline(websocket.username, *slots) + await conn.commit() + async for message in websocket: + start_time = time.time() + try: + message = json.loads(message) + type = message['type'] + seq = message['seq'] + except KeyError: + await websocket.send(json.dumps({"status": 400, "message": "Missing message data"})) + continue + except json.JSONDecodeError: + await websocket.send(json.dumps({"status": 400, "message": "Invalid JSON"})) + continue + if callback := HANDLERS.get(type): + try: + ctx = HandlerContext(message = message, pipeline = pipeline, version = version) + # TODO: Detect when payload params don't match up + # and return 400 + status, payload = await callback(ctx, **message.get("request") or {}) + reply = {"status": status, "payload": payload, "seq": seq} + if conn.in_transaction(): + await conn.commit() + except Exception: + logging.exception(f"Exception occured while handling message {message}") + reply = {"status": 500, "message": "An exception occured.", "seq": seq} + await conn.rollback() + elapsed = round(time.time() - start_time, 1) + logging.info(f"Handled {repr(type)} message from {websocket.username} with code {reply['status']} (in {elapsed}s)") + await websocket.send(json.dumps(reply)) + else: + await websocket.send(json.dumps({"status": 404, "message": f"Request type {type} does not exist", "seq": seq})) + +async def authenticate(username, key): + async with ENGINE.connect() as conn: + conn = await conn.execution_options(postgresql_readonly = True) + queue = db.Connection(conn) + pipeline = await queue.pipeline(username) + try: + await pipeline.authenticate(key) + except db.AuthenticationFailure: + return False + else: + return True 
+ +authenticator = basic_auth( + realm = "mnbot item server", + check_credentials = authenticate +) + +async def main(): + global ENGINE + ENGINE = await db.create_engine() + async with serve( + handle_connection, + "0.0.0.0", 8897, + process_request = authenticator, + max_size=2**25 + ) as server: + await server.serve_forever() + +asyncio.run(main()) diff --git a/tracker/server/app.py b/tracker/server/app.py deleted file mode 100644 index 9565f39..0000000 --- a/tracker/server/app.py +++ /dev/null @@ -1,154 +0,0 @@ -import asyncio -import base64 -import collections -import os -import json -import hmac -import typing -import logging - -from rethinkdb import r -from websockets.asyncio.server import ServerConnection, basic_auth, serve -from rue import Entry, Queue, Status, RetryBehaviour -from bot2h import Format, SendOnlyBot, Colour - -logging.basicConfig(level=logging.INFO) - -INFO_URL = os.environ['INFO_URL'] -TRACKER_BASE_URL = os.environ['TRACKER_BASE_URL'].rstrip("/") - -def item_url(id: str | None): - if not id: - return "" - return f"{TRACKER_BASE_URL}/item/{id}" - -r.set_loop_type("asyncio") -bot = SendOnlyBot(os.environ['H2IBOT_POST_URL']) - -QUEUE = Queue("mnbot") -asyncio.run(QUEUE.check()) - -async def notify_user(item: Entry, message: str): - url = item_url(item.id) - await bot.send_message(f"{item.queued_by}: Your job {item.id} for {item.item} {message} See {url} for more information.") - -HANDLERS = {} - -def handler(msg_type: str): - def decorator(f: typing.Callable): - assert msg_type not in HANDLERS - HANDLERS[msg_type] = f - return f - return decorator - -Response: typing.TypeAlias = tuple[int, typing.Optional[dict[str, typing.Any]]] -HandlerContext = collections.namedtuple("HandlerContext", ["message", "username", "version"]) - -@handler("System:ping") -async def pong(ctx: HandlerContext) -> Response: - await QUEUE.heartbeat(ctx.username) - return 204, None - -@handler("Item:claim") -async def get(ctx: HandlerContext, *, pipeline_type) -> 
Response: - item = await QUEUE.claim(ctx.username, pipeline_type, ctx.version) - if item: - payload = { - "item": item.as_json_friendly_dict() | {"_current_attempt": item.current_attempt()}, - "info_url": INFO_URL - } - else: - payload = { - "item": None, - "message": "No items found." - } - return 200, payload - -RED = Colour.make_colour(Colour.RED, escape = False) -RESET = Format.RESET -@handler("Item:fail") -async def fail(ctx: HandlerContext, *, id, message, attempt, fatal) -> Response: - item = await QUEUE.get(id) - if not item: - raise Exception("item does not exist") - behaviour = RetryBehaviour.NEVER if fatal else RetryBehaviour.DEFAULT - new_item = await QUEUE.fail(item, message, attempt, behaviour) - if new_item.status == Status.ERROR: - await notify_user(new_item, f"has {RED}failed{RESET}.") - return 204, None - -@handler("Item:store") -async def store(ctx: HandlerContext, *, id, result, attempt, result_type, decode_fields = None): - item = await QUEUE.get(id) - if not item: - raise Exception("item does not exist") - if decode_fields: - for field in decode_fields: - result[field] = base64.b85decode(result[field]) - new_item = await QUEUE.store_result(item, attempt, result, result_type) - return 201, {"new_id": new_item} - -@handler("Item:finish") -async def finish(ctx: HandlerContext, *, id) -> Response: - item = await QUEUE.get(id) - if not item: - raise Exception("item does not exist") - new_item = await QUEUE.finish(item) - await notify_user(new_item, "has finished.") - return 204, None - -async def handle_connection(websocket: ServerConnection): - initial = json.loads(await websocket.recv()) - version = initial['v'] - async for message in websocket: - try: - message = json.loads(message) - type = message['type'] - seq = message['seq'] - except KeyError: - await websocket.send(json.dumps({"status": 400, "message": "Missing message data"})) - continue - except json.JSONDecodeError: - await websocket.send(json.dumps({"status": 400, "message": "Invalid 
JSON"})) - continue - if callback := HANDLERS.get(type): - try: - ctx = HandlerContext(message = message, username = websocket.username, version = version) - # TODO: Detect when payload params don't match up - # and return 400 - status, payload = await callback(ctx, **message.get("request") or {}) - reply = {"status": status, "payload": payload, "seq": seq} - except Exception: - logging.exception(f"Exception occured while handling message {message}") - reply = {"status": 500, "message": "An exception occured.", "seq": seq} - logging.info(f"Handled {repr(type)} message from {websocket.username} with code {reply['status']}") - await websocket.send(json.dumps(reply)) - else: - await websocket.send(json.dumps({"status": 404, "message": f"Request type {type} does not exist", "seq": seq})) - -async def authenticate(username, key): - conn = await r.connect(host = os.getenv("RUE_DB_HOST", "localhost")) - try: - expected_key = await r.db("mnbot").table("server_secrets").get(username).run(conn) - return expected_key and hmac.compare_digest(key, expected_key['value']) - finally: - try: - await conn.close() - except Exception: - pass - -authenticator = basic_auth( - realm = "mnbot item server", - check_credentials = authenticate -) - -async def main(): - async with serve( - handle_connection, - "0.0.0.0", 8897, - process_request = authenticator, - max_size=2**25 - ) as server: - await server.serve_forever() - -asyncio.run(main()) diff --git a/tracker_dc/docker-compose.yml b/tracker_dc/docker-compose.yml index 6056c62..6983e15 100644 --- a/tracker_dc/docker-compose.yml +++ b/tracker_dc/docker-compose.yml @@ -25,6 +25,7 @@ services: RUE_DB_HOST: ${RUE_DB_HOST:-rethinkdb} extra_hosts: - ${EXTRA_HOST:-aaa.internal:0.0.0.0} + network_mode: host tracker: build: context: ../tracker/server @@ -43,6 +44,7 @@ services: - ${EXTRA_HOST:-aaa.internal:0.0.0.0} ports: - 8897:8897 + network_mode: host bot: build: context: ../tracker/irc @@ -59,6 +61,7 @@ services: condition: 
service_completed_successfully extra_hosts: - ${EXTRA_HOST:-aaa.internal:0.0.0.0} + network_mode: host dashboard: build: context: ../tracker/dashboard @@ -74,3 +77,4 @@ services: - ${EXTRA_HOST:-aaa.internal:0.0.0.0} ports: - 8898:8898 + network_mode: host