From 5b301b71f09aad4d2103ce1a7a06de22d9198ac0 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Wed, 27 May 2026 23:06:03 +0530 Subject: [PATCH 1/8] feat(assessment): Implement L1 pipeline with topic relevance and duplicate detection - Added L1 pipeline orchestrator to run topic relevance and duplicate detection filters in series. - Introduced duplicate detection logic to filter out vague submissions and check for duplicates against a corpus. - Created topic relevance filter to assess the relevance of submissions based on user-defined criteria. - Integrated L1 results handling in the assessment service, allowing for conditional processing based on L1 outcomes. - Updated export functionality to include L1 results in the output, enhancing the assessment reporting capabilities. - Added Celery task for executing the L1 pipeline and managing assessment run statuses. --- .../064_add_l1_columns_to_assessment_run.py | 61 ++++ backend/app/api/routes/assessment/runs.py | 3 + backend/app/celery/tasks/job_execution.py | 24 ++ backend/app/core/config.py | 7 + backend/app/crud/assessment/__init__.py | 2 + backend/app/crud/assessment/batch.py | 19 +- backend/app/crud/assessment/core.py | 50 ++- backend/app/crud/assessment/cron.py | 2 +- backend/app/models/assessment.py | 42 ++- .../app/services/assessment/l1/__init__.py | 3 + .../assessment/l1/duplicate_detection.py | 214 +++++++++++++ .../app/services/assessment/l1/pipeline.py | 225 +++++++++++++ .../services/assessment/l1/topic_relevance.py | 93 ++++++ backend/app/services/assessment/service.py | 98 +++--- backend/app/services/assessment/tasks.py | 196 ++++++++++++ .../app/services/assessment/utils/export.py | 299 +++++++++++++----- 16 files changed, 1182 insertions(+), 156 deletions(-) create mode 100644 backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py create mode 100644 backend/app/services/assessment/l1/__init__.py create mode 100644 backend/app/services/assessment/l1/duplicate_detection.py create mode 100644 backend/app/services/assessment/l1/pipeline.py create mode 100644 backend/app/services/assessment/l1/topic_relevance.py create mode 100644 backend/app/services/assessment/tasks.py diff --git a/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py b/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py new file mode 100644 index 000000000..bce33e6cd --- /dev/null +++ b/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py @@ -0,0 +1,61 @@ +"""Add L1 pipeline columns to assessment_run + +Revision ID: 064 +Revises: 063 +Create Date: 2026-05-27 00:00:00.000000 + +""" + +import sqlalchemy as sa +from alembic import op + +revision = "064" +down_revision = "063" +branch_labels = None +depends_on = None + + +def upgrade() -> None: + op.add_column( + "assessment_run", + sa.Column( + "l1_object_store_url", + sa.String(), + nullable=True, + comment="S3 URL of stored L1 filter results JSON", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "l1_total_rows", + sa.Integer(), + nullable=True, + comment="Total rows fed into L1 pipeline", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "l1_total_passed", + sa.Integer(), + nullable=True, + comment="Rows that passed topic relevance and went to L2", + ), + ) + op.add_column( + "assessment_run", + sa.Column( + "l1_total_rejected", + sa.Integer(), + nullable=True, + comment="Rows rejected by topic relevance, stopped at L1", + ), + ) + + +def downgrade() -> None: + op.drop_column("assessment_run", "l1_total_rejected") + op.drop_column("assessment_run", "l1_total_passed") + op.drop_column("assessment_run", "l1_total_rows") + op.drop_column("assessment_run", "l1_object_store_url") diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 18a9be60e..18398eeb0 100644 --- a/backend/app/api/routes/assessment/runs.py +++ b/backend/app/api/routes/assessment/runs.py @@ -65,6 +65,9 @@ def _build_run_public( total_items=run.total_items, error_message=run.error_message, input=run.input, + l1_total_rows=run.l1_total_rows, + l1_total_passed=run.l1_total_passed, + l1_total_rejected=run.l1_total_rejected, inserted_at=run.inserted_at, updated_at=run.updated_at, ) diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index adadf1c9c..ec7ad1bd0 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -232,6 +232,30 @@ def run_tts_batch_submission( ) +@celery_app.task( + bind=True, queue="low_priority", priority=1, soft_time_limit=1800, time_limit=2100 +) +def run_assessment_run( + self, + run_id: int, + organization_id: int, + project_id: int, + trace_id: str, + **kwargs, +): + from app.services.assessment.tasks import execute_assessment_run + + _set_trace(trace_id) + return _run_with_otel_parent( + self, + lambda: execute_assessment_run( + run_id=run_id, + organization_id=organization_id, + project_id=project_id, + ), + ) + + @celery_app.task(bind=True, queue="low_priority", priority=1) @gevent_timeout(settings.CELERY_TASK_SOFT_TIME_LIMIT, "run_tts_result_processing") def run_tts_result_processing( diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 720846eb9..60504147b 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -171,6 +171,13 @@ def AWS_S3_BUCKET(self) -> str: DOC_TRANSFORMATION_PENDING_THRESHOLD_MINUTES: int = 30 PENDING_JOB_QUERY_TIMEOUT_MS: int = 1000 + # Assessment + ASSESSMENT_L1_GEMINI_MODEL: str = "gemini-3.1-flash-lite" + ASSESSMENT_L1_CONCURRENT_WORKERS: int = 8 + ASSESSMENT_L1_DUPLICATE_STORE_NAME: str = ( + "fileSearchStores/inquilabcorpus-782mxjcwisaz" + ) + @computed_field # type: ignore[prop-decorator] @property def COMPUTED_CELERY_WORKER_CONCURRENCY(self) -> int: diff --git a/backend/app/crud/assessment/__init__.py b/backend/app/crud/assessment/__init__.py index cd71bff91..8e5c9984d 100644 --- a/backend/app/crud/assessment/__init__.py +++ b/backend/app/crud/assessment/__init__.py @@ -13,6 +13,7 @@ list_assessment_runs, list_assessments, recompute_assessment_status, + update_assessment_run_l1_stats, update_assessment_run_status, ) from app.crud.assessment.dataset import ( @@ -42,5 +43,6 @@ "list_assessment_datasets", "list_assessments", "recompute_assessment_status", + "update_assessment_run_l1_stats", "update_assessment_run_status", ] diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index b45603853..e5e52daec 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -161,6 +161,7 @@ def build_openai_jsonl( attachments: list[AssessmentAttachment], prompt_template: str | None, openai_params: dict, + row_indices: list[int] | None = None, ) -> list[dict[str, Any]]: """Build OpenAI batch JSONL data from dataset rows. @@ -174,7 +175,8 @@ def build_openai_jsonl( """ jsonl_data = [] - for idx, row in enumerate(rows): + for i, row in enumerate(rows): + idx = row_indices[i] if row_indices is not None else i # Build input array input_parts: list[dict[str, Any]] = [] @@ -219,6 +221,7 @@ def build_google_jsonl( attachments: list[AssessmentAttachment], prompt_template: str | None, google_params: dict, + row_indices: list[int] | None = None, ) -> list[dict[str, Any]]: """Build Google (Gemini) batch JSONL data from dataset rows. @@ -230,7 +233,8 @@ def build_google_jsonl( """ jsonl_data = [] - for idx, row in enumerate(rows): + for i, row in enumerate(rows): + idx = row_indices[i] if row_indices is not None else i parts: list[dict[str, Any]] = [] # Text prompt @@ -349,6 +353,8 @@ def submit_assessment_batch( assessment_input: dict[str, Any], organization_id: int, project_id: int, + preloaded_rows: list[dict[str, str]] | None = None, + row_indices: list[int] | None = None, ) -> BatchJob: """Build JSONL and submit a batch for one assessment run. @@ -371,8 +377,11 @@ def submit_assessment_batch( output_schema = assessment_input.get("output_schema") attachments = [AssessmentAttachment(**a) for a in attachments_raw] - # Load dataset rows - rows = _load_dataset_rows(session, dataset) + # Use preloaded rows (post-L1 filtered) if provided, else load from dataset. + if preloaded_rows is not None: + rows = preloaded_rows + else: + rows = _load_dataset_rows(session, dataset) if not rows: raise ValueError(f"Dataset {dataset.id} has no rows") @@ -412,6 +421,7 @@ def submit_assessment_batch( attachments=attachments, prompt_template=prompt_template, openai_params=mapped_params, + row_indices=row_indices, ) # Get OpenAI client and submit @@ -452,6 +462,7 @@ def submit_assessment_batch( attachments=attachments, prompt_template=prompt_template, google_params=mapped_params, + row_indices=row_indices, ) # Get Gemini client and submit diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index c91626660..6e2c8b2f7 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -223,16 +223,53 @@ def update_assessment_run_status( return run +def update_assessment_run_l1_stats( + session: Session, + run: AssessmentRun, + l1_object_store_url: str | None = None, + l1_total_rows: int | None = None, + l1_total_passed: int | None = None, + l1_total_rejected: int | None = None, +) -> AssessmentRun: + """Persist L1 result stats (rows/passed/rejected + S3 URL) on a run.""" + run.updated_at = now() + + if l1_object_store_url is not None: + run.l1_object_store_url = l1_object_store_url + if l1_total_rows is not None: + run.l1_total_rows = l1_total_rows + if l1_total_passed is not None: + run.l1_total_passed = l1_total_passed + if l1_total_rejected is not None: + run.l1_total_rejected = l1_total_rejected + + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + logger.error(f"[update_assessment_run_l1_stats] Failed: {e}", exc_info=True) + raise + + return run + + +_ACTIVE_RUN_STATUSES = frozenset( + {"l1_processing", "l2_processing", "processing", "in_progress"} +) +_FAILED_RUN_STATUSES = frozenset({"failed", "l1_failed"}) +_COMPLETED_RUN_STATUSES = frozenset({"completed", "completed_with_errors"}) + + def compute_run_counts(runs: list[AssessmentRun]) -> AssessmentRunCounts: """Aggregate child run statuses into counters.""" return AssessmentRunCounts( total=len(runs), pending=sum(1 for run in runs if run.status == "pending"), - processing=sum( - 1 for run in runs if run.status in {"processing", "in_progress"} - ), - completed=sum(1 for run in runs if run.status == "completed"), - failed=sum(1 for run in runs if run.status == "failed"), + processing=sum(1 for run in runs if run.status in _ACTIVE_RUN_STATUSES), + completed=sum(1 for run in runs if run.status in _COMPLETED_RUN_STATUSES), + failed=sum(1 for run in runs if run.status in _FAILED_RUN_STATUSES), ) @@ -267,6 +304,9 @@ def build_run_stats(runs: list[AssessmentRun]) -> list[AssessmentRunStat]: total_items=run.total_items, error_message=run.error_message, updated_at=run.updated_at, + l1_total_rows=run.l1_total_rows, + l1_total_passed=run.l1_total_passed, + l1_total_rejected=run.l1_total_rejected, ) for run in runs ] diff --git a/backend/app/crud/assessment/cron.py b/backend/app/crud/assessment/cron.py index c69b3157e..6cb76b1f5 100644 --- a/backend/app/crud/assessment/cron.py +++ b/backend/app/crud/assessment/cron.py @@ -78,7 +78,7 @@ async def poll_all_pending_assessment_evaluations( runs = get_assessment_runs_for_assessment( session=session, assessment_id=assessment.id ) - active_runs = [run for run in runs if run.status == "processing"] + active_runs = [run for run in runs if run.status == "l2_processing"] if not active_runs: refreshed = recompute_assessment_status( diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index 25ac0f00e..0dd0a96d1 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -108,7 +108,10 @@ class AssessmentRun(SQLModel, table=True): status: str = SQLField( default="pending", sa_column_kwargs={ - "comment": "Run status: pending, processing, completed, failed" + "comment": ( + "Unified pipeline status: pending, l1_processing, l1_failed, " + "l2_processing, completed, completed_with_errors, failed" + ) }, ) batch_job_id: int | None = SQLField( @@ -136,7 +139,27 @@ class AssessmentRun(SQLModel, table=True): object_store_url: str | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "S3 URL of processed batch results"}, + sa_column_kwargs={"comment": "S3 URL of processed L2 batch results"}, + ) + l1_object_store_url: str | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "S3 URL of stored L1 filter results JSON"}, + ) + l1_total_rows: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Total rows fed into L1 pipeline"}, + ) + l1_total_passed: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Rows that passed topic relevance and went to L2"}, + ) + l1_total_rejected: int | None = SQLField( + default=None, + nullable=True, + sa_column_kwargs={"comment": "Rows rejected by topic relevance, stopped at L1"}, ) error_message: str | None = SQLField( default=None, @@ -185,6 +208,9 @@ class AssessmentRunStat(BaseModel): total_items: int error_message: str | None = None updated_at: datetime | None = None + l1_total_rows: int | None = None + l1_total_passed: int | None = None + l1_total_rejected: int | None = None class AssessmentPublic(BaseModel): @@ -224,6 +250,9 @@ class AssessmentRunPublic(BaseModel): "text_columns, attachments, output_schema" ), ) + l1_total_rows: int | None = None + l1_total_passed: int | None = None + l1_total_rejected: int | None = None inserted_at: datetime updated_at: datetime @@ -286,6 +315,13 @@ class AssessmentCreate(BaseModel): configs: list[AssessmentConfigRef] = Field( ..., min_length=1, max_length=4, description="Config versions to run" ) + l1_config: dict[str, Any] | None = Field( + None, + description=( + "L1 pipeline config. Keys: topic_relevance (columns, prompt), " + "duplicate_detection (columns). Omit to skip L1." + ), + ) class AssessmentRunSummary(BaseModel): @@ -324,6 +360,8 @@ class AssessmentExportRow(BaseModel): row_id: str result_status: str input_data: dict[str, str] | None = None + topic_relevance: str | None = None + duplicate_detection: str | None = None output: str | None = None error: str | None = None response_id: str | None = None diff --git a/backend/app/services/assessment/l1/__init__.py b/backend/app/services/assessment/l1/__init__.py new file mode 100644 index 000000000..66e3a0374 --- /dev/null +++ b/backend/app/services/assessment/l1/__init__.py @@ -0,0 +1,3 @@ +from app.services.assessment.l1.pipeline import run_l1_pipeline + +__all__ = ["run_l1_pipeline"] diff --git a/backend/app/services/assessment/l1/duplicate_detection.py b/backend/app/services/assessment/l1/duplicate_detection.py new file mode 100644 index 000000000..608389c1d --- /dev/null +++ b/backend/app/services/assessment/l1/duplicate_detection.py @@ -0,0 +1,214 @@ +"""Duplicate detection filter for L1 pipeline.""" + +import json +import logging +import re +from typing import Any + +from google import genai +from google.genai import types + +logger = logging.getLogger(__name__) + +_VAGUE_SYS = """ +You are a strict VAGUENESS gate for the School Innovation Marathon (SIM) +duplicate-detection pipeline. Submissions come from Indian school students grades 6-12. +You run BEFORE corpus duplicate detection. Decide only if the submission has enough +surface area for corpus matching. NOT a quality gate. + +NOT VAGUE (let through to corpus check): +- Widely-known/textbook ideas (rainwater harvesting, anti-theft alarm) +- Weak novelty / unclear feasibility +- Hindi/Telugu/mixed Indian-language text +- Bad grammar or rambling if content present +- Long essays naming domain + audience + any mechanism + +VAGUE only when ALL: problem names no issue/target/domain, solution names no mechanism, +text is empty / aspirational ("make society better") / gibberish. + +DECISION: 0-1 clear dimensions present -> vague=true. 2+ -> vague=false. Borderline -> false. + +Output ONLY JSON: {"vague": true|false, "reason": "max 15 words"} +""" + +_DUP_SYS = """ +You are a strict duplicate-detection judge for an innovation competition corpus. + +Given a submitted idea, search the corpus and compare precisely. +Focus on MECHANISM of the solution, not category or theme. + +Verdict (exactly one): DUPLICATE / OVERLAP / PARTIAL_MATCH / UNIQUE + + DUPLICATE: Both problem AND solution mechanism substantially match a corpus entry. + OVERLAP: Either problem OR solution mechanism matches, other side clearly different. + PARTIAL_MATCH: Thematic/conceptual similarity only — same domain, different mechanism. + UNIQUE: Neither problem nor solution substantially matches anything in corpus. + +Response format (follow exactly): +Verdict: +Title: +Source: +URL: +Matching sentence: +Reason: + +RULES: +- UNIQUE -> output ONLY Verdict + Reason. +- NOT UNIQUE -> Title, Source, URL, Matching sentence ALL required. +- Source/URL MUST be VERBATIM from "SOURCE_URL:" line in retrieved chunk. +- NEVER write filenames, page numbers, or constructed URLs. +""" + + +def _build_combined(content_parts: dict[str, str]) -> str: + parts = [f"{col}:\n{val}" for col, val in content_parts.items() if val.strip()] + return "\n\n".join(parts) + + +def _check_vague( + text: str, + gemini_client: genai.Client, + model: str, +) -> tuple[bool, str]: + try: + response = gemini_client.models.generate_content( + model=model, + contents=f"Submission:\n\n{text}", + config=types.GenerateContentConfig( + system_instruction=_VAGUE_SYS, + response_mime_type="application/json", + temperature=0.0, + ), + ) + parsed = json.loads((response.text or "").strip()) + return bool(parsed.get("vague", False)), str(parsed.get("reason", "")) + except Exception as exc: + logger.warning("[_check_vague] Parse error — defaulting not vague | %s", exc) + return False, "(vague check error — defaulting to not vague)" + + +def _call_file_search( + text: str, + gemini_client: genai.Client, + model: str, + store_name: str, +) -> str: + response = gemini_client.models.generate_content( + model=model, + contents=f"Submitted idea to check for duplicates:\n\n{text}", + config=types.GenerateContentConfig( + system_instruction=_DUP_SYS, + tools=[ + types.Tool( + file_search=types.FileSearch(file_search_store_names=[store_name]) + ) + ], + temperature=0.0, + ), + ) + return response.text or "" + + +_VERDICT_VALUES = {"DUPLICATE", "OVERLAP", "PARTIAL_MATCH", "UNIQUE"} + + +def _parse_verdict(raw: str) -> dict[str, str | None]: + fields: dict[str, str | None] = { + "verdict": "", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": None, + } + keymap = { + "verdict": "verdict", + "title": "match_title", + "source": "source_url", + "url": "source_url", + "matching sentence": "matching_sentence", + "reason": "reason", + } + for line in (raw or "").splitlines(): + if ":" not in line: + continue + k, _, v = line.partition(":") + norm = re.sub(r"[^a-z\s]", "", k.strip().lower()).strip() + if norm in keymap: + fields[keymap[norm]] = v.strip() or None + + # Fallback: scan entire response for a known verdict token + if not fields["verdict"] or fields["verdict"] not in _VERDICT_VALUES: + m = re.search(r"\b(DUPLICATE|OVERLAP|PARTIAL_MATCH|UNIQUE)\b", raw or "") + if m: + fields["verdict"] = m.group(1) + logger.warning( + "[_parse_verdict] key-based parse missed verdict; regex fallback found: %s", + fields["verdict"], + ) + else: + logger.warning( + "[_parse_verdict] verdict not found in response. raw=%r", + (raw or "")[:500], + ) + + return fields + + +def run_duplicate_detection( + row_idx: int, + row: dict[str, str], + columns: list[str], + gemini_client: genai.Client, + model: str, + store_name: str, +) -> dict[str, Any]: + """Run duplicate detection on a single row. + + Returns a dict with: row_id, verdict, match_title, source_url, + matching_sentence, reason. + Always passthrough — never gates L2. + """ + content_parts = {col: row.get(col, "") for col in columns} + combined = _build_combined(content_parts) or "(empty submission)" + + try: + is_vague, vague_reason = _check_vague(combined, gemini_client, model) + except Exception as exc: + logger.warning( + "[run_duplicate_detection] Vague check failed row_%s | %s", row_idx, exc + ) + is_vague, vague_reason = False, f"(vague check error: {exc})" + + if is_vague: + return { + "row_id": f"row_{row_idx}", + "verdict": "VAGUE", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": vague_reason, + } + + try: + raw = _call_file_search(combined, gemini_client, model, store_name) + parsed = _parse_verdict(raw) + return { + "row_id": f"row_{row_idx}", + "verdict": parsed["verdict"] or "UNKNOWN", + "match_title": parsed["match_title"], + "source_url": parsed["source_url"], + "matching_sentence": parsed["matching_sentence"], + "reason": parsed["reason"], + } + except Exception as exc: + logger.warning( + "[run_duplicate_detection] File search failed row_%s | %s", row_idx, exc + ) + return { + "row_id": f"row_{row_idx}", + "verdict": "ERROR", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": str(exc)[:200], + } diff --git a/backend/app/services/assessment/l1/pipeline.py b/backend/app/services/assessment/l1/pipeline.py new file mode 100644 index 000000000..2a002e5e5 --- /dev/null +++ b/backend/app/services/assessment/l1/pipeline.py @@ -0,0 +1,225 @@ +"""L1 pipeline orchestrator. + +Runs two filters in series for each row: +1. Topic Relevance (go/no-go) — REJECT stops the row. +2. Duplicate Detection (passthrough) — only on ACCEPTED rows. + +""" + +import json +import logging +from concurrent.futures import ThreadPoolExecutor, as_completed +from typing import Any + +from sqlmodel import Session + +from app.core.batch.client import GeminiClient +from app.core.config import settings +from app.core.cloud import get_cloud_storage +from app.core.storage_utils import upload_jsonl_to_object_store +from app.models.assessment import AssessmentRun +from app.services.assessment.l1.duplicate_detection import run_duplicate_detection +from app.services.assessment.l1.topic_relevance import run_topic_relevance + +logger = logging.getLogger(__name__) + + +def _build_l1_result( + row_idx: int, + tr_result: dict[str, Any] | None, + dup_result: dict[str, Any] | None, +) -> dict[str, Any]: + return { + "row_id": f"row_{row_idx}", + "l1_passed": tr_result["verdict"] if tr_result else True, + "topic_relevance": { + "decision": tr_result["decision"], + "column_relevance": tr_result.get("column_relevance") or {}, + "reasoning": tr_result["reasoning"], + } + if tr_result + else None, + "duplicate_detection": dup_result, + } + + +def run_l1_pipeline( + run: AssessmentRun, + rows: list[dict[str, str]], + l1_config: dict[str, Any], + session: Session, + organization_id: int, + project_id: int, +) -> tuple[list[dict[str, str]], list[int], list[dict[str, Any]]]: + """Run L1 filters on all rows. + + Args: + run: The AssessmentRun record (used for S3 path and DB update). + rows: Full dataset rows loaded from object store. + l1_config: User-supplied config with topic_relevance and duplicate_detection keys. + session: DB session. + organization_id: For Gemini credential lookup. + project_id: For Gemini credential lookup and S3 storage. + + Returns: + (passed_rows, passed_indices, all_l1_results) + passed_rows: subset of rows where topic_relevance verdict=true. + passed_indices: original dataset indices of passed_rows (used to preserve row IDs in L2). + all_l1_results: one entry per input row (len == len(rows)). + """ + model = settings.ASSESSMENT_L1_GEMINI_MODEL + workers = settings.ASSESSMENT_L1_CONCURRENT_WORKERS + store_name = settings.ASSESSMENT_L1_DUPLICATE_STORE_NAME + + tr_config = l1_config.get("topic_relevance") or {} + dup_config = l1_config.get("duplicate_detection") or {} + + tr_columns: list[str] = tr_config.get("columns") or [] + tr_prompt: str = tr_config.get("prompt") or "" + dup_columns: list[str] = dup_config.get("columns") or [] + + tr_enabled = bool(tr_columns and tr_prompt) + dup_enabled = bool(dup_columns) + + if not tr_enabled and not dup_enabled: + logger.warning( + "[run_l1_pipeline] run_id=%s — no L1 filters configured, skipping L1", + run.id, + ) + return rows, list(range(len(rows))), [] + + gemini_client = GeminiClient.from_credentials( + session=session, + org_id=organization_id, + project_id=project_id, + ).client + + logger.info( + "[run_l1_pipeline] run_id=%s | rows=%s | model=%s | workers=%s | tr=%s | dup=%s", + run.id, + len(rows), + model, + workers, + tr_enabled, + dup_enabled, + ) + + # tr_results[idx] = None when TR disabled → no topic_relevance columns in export + tr_results: dict[int, dict[str, Any] | None] = {} + if tr_enabled: + with ThreadPoolExecutor(max_workers=workers) as executor: + futs = { + executor.submit( + run_topic_relevance, + idx, + row, + tr_columns, + tr_prompt, + gemini_client, + model, + ): idx + for idx, row in enumerate(rows) + } + for fut in as_completed(futs): + idx = futs[fut] + try: + tr_results[idx] = fut.result() + except Exception as exc: + logger.warning( + "[run_l1_pipeline] TR future error row_%s | %s", idx, exc + ) + tr_results[idx] = { + "row_id": f"row_{idx}", + "verdict": True, + "decision": "ACCEPT", + "column_relevance": {}, + "reasoning": f"(future error — defaulting to pass) {exc}", + } + passed_indices = [idx for idx, r in tr_results.items() if r and r["verdict"]] + else: + for idx in range(len(rows)): + tr_results[idx] = None + passed_indices = list(range(len(rows))) + + rejected_count = len(rows) - len(passed_indices) + logger.info( + "[run_l1_pipeline] run_id=%s | TR done | passed=%s | rejected=%s", + run.id, + len(passed_indices), + rejected_count, + ) + + dup_results: dict[int, dict[str, Any]] = {} + if dup_columns and passed_indices: + with ThreadPoolExecutor(max_workers=workers) as executor: + futs = { + executor.submit( + run_duplicate_detection, + idx, + rows[idx], + dup_columns, + gemini_client, + model, + store_name, + ): idx + for idx in passed_indices + } + for fut in as_completed(futs): + idx = futs[fut] + try: + dup_results[idx] = fut.result() + except Exception as exc: + logger.warning( + "[run_l1_pipeline] DUP future error row_%s | %s", idx, exc + ) + dup_results[idx] = { + "row_id": f"row_{idx}", + "verdict": "ERROR", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": str(exc)[:200], + } + + all_l1_results: list[dict[str, Any]] = [ + _build_l1_result(idx, tr_results[idx], dup_results.get(idx)) + for idx in range(len(rows)) + ] + + l1_object_store_url: str | None = None + try: + storage = get_cloud_storage(session=session, project_id=project_id) + l1_object_store_url = upload_jsonl_to_object_store( + storage=storage, + results=all_l1_results, + filename="l1_results.json", + subdirectory=f"assessment/run-{run.id}/l1", + format="json", + ) + logger.info( + "[run_l1_pipeline] run_id=%s | L1 results uploaded to %s", + run.id, + l1_object_store_url, + ) + except Exception as exc: + logger.error( + "[run_l1_pipeline] run_id=%s | S3 upload failed | %s", + run.id, + exc, + exc_info=True, + ) + + from app.crud.assessment.core import update_assessment_run_l1_stats + + update_assessment_run_l1_stats( + session=session, + run=run, + l1_object_store_url=l1_object_store_url, + l1_total_rows=len(rows), + l1_total_passed=len(passed_indices), + l1_total_rejected=rejected_count, + ) + + sorted_passed_indices = sorted(passed_indices) + passed_rows = [rows[idx] for idx in sorted_passed_indices] + return passed_rows, sorted_passed_indices, all_l1_results diff --git a/backend/app/services/assessment/l1/topic_relevance.py b/backend/app/services/assessment/l1/topic_relevance.py new file mode 100644 index 000000000..42516d27b --- /dev/null +++ b/backend/app/services/assessment/l1/topic_relevance.py @@ -0,0 +1,93 @@ +"""Topic relevance filter for L1 pipeline. +""" + +import json +import logging +from typing import Any + +from google import genai +from google.genai import types + +logger = logging.getLogger(__name__) + + +def _build_output_schema(columns: list[str]) -> dict[str, Any]: + """Build output schema: locked decision + per-column relevance booleans + reasoning.""" + props: dict[str, Any] = { + "decision": { + "type": "string", + "enum": ["ACCEPT", "REJECT"], + "description": "Final verdict. ACCEPT to proceed to full evaluation, REJECT to stop here.", + }, + } + required = ["decision"] + + for col in columns: + props[col] = { + "type": "boolean", + "description": f"Whether the '{col}' column content is relevant to the topic.", + } + required.append(col) + + props["reasoning"] = { + "type": "string", + "description": "Explanation of the verdict and per-column relevance assessment.", + } + required.append("reasoning") + + return {"type": "object", "properties": props, "required": required} + + +def run_topic_relevance( + row_idx: int, + row: dict[str, str], + columns: list[str], + user_prompt: str, + gemini_client: genai.Client, + model: str, +) -> dict[str, Any]: + """Run topic relevance check on a single row. + + System instruction = user_prompt (the evaluation rubric/criteria). + User content = dict of {column_name: value} for the selected columns. + Output schema enforced: decision (ACCEPT/REJECT) + reasoning. + On error defaults to verdict=True (fail-open). + """ + user_content = json.dumps({col: row.get(col, "") or "" for col in columns}) + output_schema = _build_output_schema(columns) + + try: + response = gemini_client.models.generate_content( + model=model, + contents=user_content, + config=types.GenerateContentConfig( + system_instruction=user_prompt.strip(), + response_mime_type="application/json", + response_schema=output_schema, + temperature=0.0, + ), + ) + raw = (response.text or "").strip() + parsed = json.loads(raw) + decision = str(parsed.get("decision", "ACCEPT")).upper() + column_relevance = {col: bool(parsed.get(col, True)) for col in columns} + return { + "row_id": f"row_{row_idx}", + "verdict": decision == "ACCEPT", + "decision": decision, + "column_relevance": column_relevance, + "reasoning": str(parsed.get("reasoning", "")), + } + except Exception as exc: + logger.warning( + "[run_topic_relevance] row_%s error — defaulting verdict=True | %s", + row_idx, + exc, + ) + return { + "row_id": f"row_{row_idx}", + "verdict": True, + "decision": "ACCEPT", + "column_relevance": {col: True for col in columns}, + "reasoning": f"(evaluation error — defaulting to pass) {exc}", + } diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index 45a283ea5..d03eb13a7 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -1,9 +1,9 @@ """Assessment run orchestration service.""" import logging -from typing import Any from uuid import UUID +from asgi_correlation_id import correlation_id from fastapi import HTTPException from sqlmodel import Session @@ -13,9 +13,7 @@ get_assessment_dataset_by_id, get_assessment_runs_for_assessment, recompute_assessment_status, - update_assessment_run_status, ) -from app.crud.assessment.batch import submit_assessment_batch from app.crud.config import ConfigCrud from app.crud.evaluations.core import resolve_evaluation_config from app.models.assessment import ( @@ -81,6 +79,7 @@ def _build_retry_request( attachments=[AssessmentAttachment.model_validate(item) for item in attachments], output_schema=assessment_input.get("output_schema"), configs=configs, + l1_config=assessment_input.get("l1_config"), ) @@ -90,11 +89,13 @@ def start_assessment( organization_id: int, project_id: int, ) -> AssessmentResponse: - """Start an assessment run request. + """Validate, create Assessment + AssessmentRun records, dispatch Celery tasks. - Validates the dataset, resolves each config, creates one AssessmentRun per config, - and kicks off batch processing for each. + Each run is created with status='pending' and handed off to a Celery worker + that runs L1 filtering then submits the L2 batch. """ + from app.celery.tasks.job_execution import run_assessment_run + logger.info( "[start_assessment] Starting | experiment=%s | dataset_id=%s | configs=%s | org_id=%s", request.experiment_name, @@ -110,7 +111,7 @@ def start_assessment( project_id=project_id, ) - assessment_input: dict[str, Any] = { + assessment_input: dict = { "prompt_template": request.prompt_template, "system_instruction": request.system_instruction, "text_columns": request.text_columns, @@ -118,12 +119,13 @@ def start_assessment( } if request.output_schema: assessment_input["output_schema"] = request.output_schema + if request.l1_config: + assessment_input["l1_config"] = request.l1_config config_crud = ConfigCrud(session=session, project_id=project_id) resolved_configs = [] for cfg in request.configs: - # Assessment runs must use configs explicitly tagged for assessment use. parent_config = config_crud.read_one(cfg.config_id) if parent_config is not None and parent_config.tag != ConfigTag.ASSESSMENT: tag_value = ( @@ -165,7 +167,7 @@ def start_assessment( f"Supported providers: {sorted(_SUPPORTED_BATCH_PROVIDERS)}" ), ) - resolved_configs.append((cfg, config_blob)) + resolved_configs.append(cfg) assessment = create_assessment( session=session, @@ -176,54 +178,30 @@ def start_assessment( ) runs: list[AssessmentRun] = [] - try: - for cfg, config_blob in resolved_configs: - run = create_assessment_run( - session=session, - assessment_id=assessment.id, - config_id=cfg.config_id, - config_version=cfg.config_version, - assessment_input=assessment_input, - ) - - try: - batch_job = submit_assessment_batch( - session=session, - run=run, - assessment=assessment, - dataset=dataset, - config_blob=config_blob, - assessment_input=assessment_input, - organization_id=organization_id, - project_id=project_id, - ) + trace_id = correlation_id.get() or "" - run = update_assessment_run_status( - session=session, - run=run, - status="processing", - batch_job_id=batch_job.id, - total_items=batch_job.total_items, - ) + for cfg in resolved_configs: + run = create_assessment_run( + session=session, + assessment_id=assessment.id, + config_id=cfg.config_id, + config_version=cfg.config_version, + assessment_input=assessment_input, + ) + runs.append(run) - except Exception as e: - logger.error( - "[start_assessment] Failed to submit batch for run %s: %s", - run.id, - e, - exc_info=True, - ) - run = update_assessment_run_status( - session=session, - run=run, - status="failed", - error_message="Batch submission failed. Please try again or contact support.", - ) + run_assessment_run.delay( + run_id=run.id, + organization_id=organization_id, + project_id=project_id, + trace_id=trace_id, + ) - runs.append(run) - except Exception: - recompute_assessment_status(session=session, assessment_id=assessment.id) - raise + logger.info( + "[start_assessment] Dispatched Celery task | run_id=%s | config_id=%s", + run.id, + cfg.config_id, + ) recompute_assessment_status(session=session, assessment_id=assessment.id) @@ -242,13 +220,13 @@ def start_assessment( num_configs=len(runs), runs=[ AssessmentRunSummary( - run_id=completed_run.id, - assessment_id=completed_run.assessment_id, - config_id=str(completed_run.config_id), - config_version=completed_run.config_version, - status=completed_run.status, + run_id=run.id, + assessment_id=run.assessment_id, + config_id=str(run.config_id), + config_version=run.config_version, + status=run.status, ) - for completed_run in runs + for run in runs ], ) diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py new file mode 100644 index 000000000..66f644050 --- /dev/null +++ b/backend/app/services/assessment/tasks.py @@ -0,0 +1,196 @@ +"""Celery task logic for running a single assessment run (L1 → L2 batch submit).""" + +import logging + +from sqlmodel import Session + +from app.core.db import engine +from app.crud.assessment import ( + get_assessment_dataset_by_id, + recompute_assessment_status, + update_assessment_run_status, +) +from app.crud.assessment.batch import _load_dataset_rows, submit_assessment_batch +from app.crud.config import ConfigCrud +from app.crud.evaluations.core import resolve_evaluation_config +from app.models.assessment import Assessment, AssessmentRun +from app.models.config.config import ConfigTag +from app.services.assessment.l1 import run_l1_pipeline + +logger = logging.getLogger(__name__) + + +def execute_assessment_run( + run_id: int, + organization_id: int, + project_id: int, +) -> None: + """Run L1 filtering then submit L2 batch for one AssessmentRun. + + Status transitions: + pending → l1_processing → l1_failed (stop) + → l2_processing → (cron handles rest) + pending → l2_processing (when no l1_config) + """ + with Session(engine) as session: + run = session.get(AssessmentRun, run_id) + if run is None: + logger.error("[execute_assessment_run] run_id=%s not found", run_id) + return + + assessment = session.get(Assessment, run.assessment_id) + if assessment is None: + logger.error( + "[execute_assessment_run] parent assessment %s not found for run %s", + run.assessment_id, + run_id, + ) + return + + assessment_input = run.input or {} + dataset_id = assessment.dataset_id + + dataset = get_assessment_dataset_by_id( + session=session, + dataset_id=dataset_id, + organization_id=organization_id, + project_id=project_id, + ) + + config_crud = ConfigCrud(session=session, project_id=project_id) + parent_config = config_crud.read_one(run.config_id) + if parent_config is not None and parent_config.tag != ConfigTag.ASSESSMENT: + logger.error( + "[execute_assessment_run] config %s has wrong tag for run %s", + run.config_id, + run_id, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Config tag is not ASSESSMENT.", + ) + recompute_assessment_status(session=session, assessment_id=assessment.id) + return + + config_blob, error = resolve_evaluation_config( + session=session, + config_id=run.config_id, + config_version=run.config_version, + project_id=project_id, + tag=ConfigTag.ASSESSMENT, + ) + if error or config_blob is None: + logger.error( + "[execute_assessment_run] config resolution failed run_id=%s: %s", + run_id, + error, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message=f"Config resolution failed: {error}", + ) + recompute_assessment_status(session=session, assessment_id=assessment.id) + return + + all_rows = _load_dataset_rows(session=session, dataset=dataset) + if not all_rows: + logger.error( + "[execute_assessment_run] dataset %s has no rows for run %s", + dataset_id, + run_id, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Dataset has no rows.", + ) + recompute_assessment_status(session=session, assessment_id=assessment.id) + return + + # L1 pipeline + rows_for_l2 = all_rows + row_indices_for_l2: list[int] | None = None + l1_config = assessment_input.get("l1_config") + if l1_config: + update_assessment_run_status( + session=session, run=run, status="l1_processing" + ) + try: + rows_for_l2, row_indices_for_l2, _ = run_l1_pipeline( + run=run, + rows=all_rows, + l1_config=l1_config, + session=session, + organization_id=organization_id, + project_id=project_id, + ) + logger.info( + "[execute_assessment_run] L1 done | run_id=%s | rows_to_l2=%s / %s", + run_id, + len(rows_for_l2), + len(all_rows), + ) + except Exception as l1_exc: + logger.error( + "[execute_assessment_run] L1 failed run_id=%s | %s", + run_id, + l1_exc, + exc_info=True, + ) + update_assessment_run_status( + session=session, + run=run, + status="l1_failed", + error_message=f"L1 pipeline failed: {l1_exc}", + ) + recompute_assessment_status( + session=session, assessment_id=assessment.id + ) + return # L2 does not run when L1 fails + + # L2 batch submit + try: + batch_job = submit_assessment_batch( + session=session, + run=run, + assessment=assessment, + dataset=dataset, + config_blob=config_blob, + assessment_input=assessment_input, + organization_id=organization_id, + project_id=project_id, + preloaded_rows=rows_for_l2, + row_indices=row_indices_for_l2, + ) + update_assessment_run_status( + session=session, + run=run, + status="l2_processing", + batch_job_id=batch_job.id, + total_items=batch_job.total_items, + ) + logger.info( + "[execute_assessment_run] L2 batch submitted | run_id=%s | batch_job_id=%s", + run_id, + batch_job.id, + ) + except Exception as e: + logger.error( + "[execute_assessment_run] L2 batch submit failed run_id=%s: %s", + run_id, + e, + exc_info=True, + ) + update_assessment_run_status( + session=session, + run=run, + status="failed", + error_message="Batch submission failed. Please try again or contact support.", + ) + + recompute_assessment_status(session=session, assessment_id=assessment.id) diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index ca273afc6..d244ecd08 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -22,6 +22,8 @@ from app.services.assessment.utils.parsing import parse_stored_results, usage_totals from app.utils import APIResponse +_L1_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] + logger = logging.getLogger(__name__) @@ -34,6 +36,29 @@ def _load_dataset_rows( return load_dataset_rows(session, dataset) +def _load_l1_results( + session: Session, + run: AssessmentRun, + assessment: Assessment, +) -> dict[str, dict[str, Any]]: + """Load L1 results from object store, keyed by row_id. Returns {} if unavailable.""" + if not run.l1_object_store_url: + return {} + try: + storage = get_cloud_storage(session, project_id=assessment.project_id) + body = storage.stream(run.l1_object_store_url) + raw = body.read().decode("utf-8") + results: list[dict[str, Any]] = json.loads(raw) + return {str(item["row_id"]): item for item in results if "row_id" in item} + except Exception as exc: + logger.warning( + "[_load_l1_results] Failed to load L1 results for run id=%s: %s", + run.id, + exc, + ) + return {} + + def _safe_filename_part(value: str) -> str: """Build a filesystem-safe filename component.""" sanitized = re.sub(r"[^A-Za-z0-9._-]+", "_", value).strip("._") @@ -113,86 +138,99 @@ def _drop_empty_columns( return pruned, non_empty_fields +def _parse_json_col(raw: Any) -> dict[str, Any] | None: + if raw is None: + return None + if isinstance(raw, dict): + return raw + if isinstance(raw, str): + try: + parsed = json.loads(raw) + return parsed if isinstance(parsed, dict) else None + except (json.JSONDecodeError, TypeError): + return None + return None + + def _expand_output_columns( row_payload: list[dict[str, Any]], ) -> tuple[list[dict[str, Any]], list[str]]: - """Expand the ``output`` field into separate columns when it contains valid JSON. + """Expand ``output``, ``topic_relevance``, and ``duplicate_detection`` JSON columns + into separate flat columns when they contain valid JSON objects. Returns: (expanded_rows, ordered_fieldnames) """ - # First expand input columns row_payload, input_col_names = _expand_input_columns(row_payload) + json_expand_cols = {"output", "input_data"} | set(_L1_JSON_COLUMNS) base_fields = [ field for field in AssessmentExportRow.model_fields.keys() - if field not in ("output", "input_data") + if field not in json_expand_cols ] - parsed_outputs: list[dict[str, Any] | None] = [] - output_keys: list[str] = [] - seen_keys: dict[str, None] = {} # ordered set + # L1 columns are prefixed with their parent name to avoid key collisions + parsed_cols: dict[str, list[dict[str, Any] | None]] = { + col: [] for col in ["output"] + _L1_JSON_COLUMNS + } + col_keys: dict[str, list[str]] = {col: [] for col in ["output"] + _L1_JSON_COLUMNS} + col_seen: dict[str, dict[str, None]] = { + col: {} for col in ["output"] + _L1_JSON_COLUMNS + } has_unparsed_output = False for row in row_payload: - raw = row.get("output") - if raw is None: - parsed_outputs.append(None) - continue - - if isinstance(raw, str): - try: - parsed = json.loads(raw) - except (json.JSONDecodeError, TypeError): - parsed = None - elif isinstance(raw, dict): - parsed = raw - else: - parsed = None - - if not isinstance(parsed, dict): - has_unparsed_output = True - parsed_outputs.append(None) - continue - - parsed_outputs.append(parsed) - for output_key in parsed: - if output_key not in seen_keys: - seen_keys[output_key] = None - output_keys.append(output_key) - - if not output_keys: - # Keep original layout with output as a single column - fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) - fieldnames = [field for field in fieldnames if field != "input_data"] - return row_payload, fieldnames + for col in ["output"] + _L1_JSON_COLUMNS: + parsed = _parse_json_col(row.get(col)) + if parsed is None and col == "output" and row.get(col) is not None: + has_unparsed_output = True + parsed_cols[col].append(parsed) + if parsed: + for k in parsed: + prefixed = f"{col}_{k}" if col in _L1_JSON_COLUMNS else k + if prefixed not in col_seen[col]: + col_seen[col][prefixed] = None + col_keys[col].append(prefixed) + + def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: + if not parsed: + return {} + if col in _L1_JSON_COLUMNS: + return {f"{col}_{k}": v for k, v in parsed.items()} + return parsed # Build expanded rows expanded: list[dict[str, Any]] = [] - for row, parsed in zip(row_payload, parsed_outputs, strict=True): - new_row = {col: val for col, val in row.items() if col != "output"} - if parsed: - for output_key in output_keys: - new_row[output_key] = parsed.get(output_key) - else: - for output_key in output_keys: - new_row[output_key] = None - if row.get("output") is not None: - new_row["output_raw"] = row.get("output") + for i, row in enumerate(row_payload): + new_row = {k: v for k, v in row.items() if k not in json_expand_cols} + for col in ["output"] + _L1_JSON_COLUMNS: + parsed = parsed_cols[col][i] + keys = col_keys[col] + prefixed_vals = _get_prefixed(parsed, col) + if prefixed_vals: + for k in keys: + new_row[k] = prefixed_vals.get(k) + else: + for k in keys: + new_row[k] = None + if col == "output" and row.get("output") is not None: + new_row["output_raw"] = row.get("output") expanded.append(new_row) - # Build fieldnames: input columns + base fields + output columns - output_idx = base_fields.index("result_status") + 1 # after result_status - fieldnames = ( - input_col_names - + base_fields[:output_idx] - + output_keys - + base_fields[output_idx:] - ) + l1_keys = col_keys["topic_relevance"] + col_keys["duplicate_detection"] + output_keys = col_keys["output"] + + all_output_keys = l1_keys + output_keys + if not all_output_keys: + fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) + fieldnames = [f for f in fieldnames if f != "input_data"] + return row_payload, fieldnames + + fieldnames = input_col_names + l1_keys + output_keys + base_fields if has_unparsed_output: fieldnames.insert( - len(input_col_names) + output_idx + len(output_keys), "output_raw" + len(input_col_names) + len(l1_keys) + len(output_keys), "output_raw" ) return expanded, fieldnames @@ -212,7 +250,6 @@ def serialize_export_rows( "application/json", ) - # For CSV/XLSX, expand output keys into separate columns expanded, fieldnames = _expand_output_columns(row_payload) if export_format == "csv": @@ -230,7 +267,6 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - # XLSX shows input columns + output columns only (no metadata fields). metadata_fields = { field for field in AssessmentExportRow.model_fields.keys() @@ -376,59 +412,154 @@ def _load_dataset_rows_for_run( return [] +def _extract_l1_json_columns( + l1_item: dict[str, Any] | None, +) -> dict[str, Any]: + """Return topic_relevance and duplicate_detection as JSON strings for export expansion.""" + if not l1_item: + return {"topic_relevance": None, "duplicate_detection": None} + + tr = l1_item.get("topic_relevance") + dup = l1_item.get("duplicate_detection") + + tr_flat: dict[str, Any] | None = None + if tr: + tr_flat = {} + for col, val in (tr.get("column_relevance") or {}).items(): + tr_flat[col] = val + tr_flat["decision"] = tr.get("decision") + tr_flat["reasoning"] = tr.get("reasoning") + + dup_flat: dict[str, Any] | None = None + if dup: + dup_flat = {k: v for k, v in dup.items() if k != "row_id"} + + return { + "topic_relevance": json.dumps(tr_flat, ensure_ascii=False) if tr_flat else None, + "duplicate_detection": json.dumps(dup_flat, ensure_ascii=False) + if dup_flat + else None, + } + + def load_export_rows_for_run( session: Session, run: AssessmentRun, assessment: Assessment | None = None, ) -> list[AssessmentExportRow]: - """Load flattened export rows for a single child assessment run.""" - if not run.batch_job_id: + """Load flattened export rows for a single child assessment run. + + When L1 results exist, ALL dataset rows are included in output. + L1-rejected rows have L1 columns filled and L2 columns empty. + L1-passed rows have all columns filled. + Without L1, behaviour is unchanged (only L2 result rows returned). + """ + if assessment is None: + assessment = session.get(Assessment, run.assessment_id) + if assessment is None: logger.warning( - "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id + "[load_export_rows_for_run] Parent assessment missing for run id=%s", + run.id, ) return [] - batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) - if not batch_job: + dataset = session.get(EvaluationDataset, assessment.dataset_id) + dataset_name = dataset.name if dataset else None + dataset_rows = _load_dataset_rows_for_run(session, run, assessment) + + # Load L1 results (empty dict if no L1 was run) + l1_by_row_id = _load_l1_results(session, run, assessment) + + # Load L2 results (may be None if batch not complete) + l2_by_row_id: dict[str, dict[str, Any]] = {} + if run.batch_job_id: + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if batch_job: + parsed_results = _load_parsed_results_for_run( + session=session, run=run, batch_job=batch_job + ) + if parsed_results: + l2_by_row_id = { + str(item["row_id"]): item + for item in parsed_results + if "row_id" in item + } + + has_l1 = bool(l1_by_row_id) + + if has_l1 and dataset_rows: + # All rows in output — build from full dataset + export_rows: list[AssessmentExportRow] = [] + for row_idx, input_data in enumerate(dataset_rows): + row_id_str = f"row_{row_idx}" + l1_item = l1_by_row_id.get(row_id_str) + l1_cols = _extract_l1_json_columns(l1_item) + l2_item = l2_by_row_id.get(row_id_str) + + input_tokens, output_tokens, total_tokens = usage_totals( + l2_item.get("usage") if l2_item else None + ) + l1_passed = (l1_item or {}).get("l1_passed", True) + result_status = ( + "l1_rejected" + if not l1_passed + else ("failed" if l2_item and l2_item.get("error") else "passed") + ) + + export_rows.append( + AssessmentExportRow( + assessment_id=run.assessment_id, + experiment_name=assessment.experiment_name, + dataset_id=assessment.dataset_id, + dataset_name=dataset_name, + run_id=run.id, + run_name=assessment.experiment_name, + run_status=run.status, + config_id=run.config_id, + config_version=run.config_version, + row_id=row_id_str, + result_status=result_status, + input_data=input_data, + topic_relevance=l1_cols.get("topic_relevance"), + duplicate_detection=l1_cols.get("duplicate_detection"), + output=l2_item.get("output") if l2_item else None, + error=l2_item.get("error") if l2_item else None, + response_id=l2_item.get("response_id") if l2_item else None, + input_tokens=input_tokens, + output_tokens=output_tokens, + total_tokens=total_tokens, + updated_at=run.updated_at, + ) + ) + return export_rows + + # No L1 — original behaviour: only L2 result rows + if not run.batch_job_id: logger.warning( - "[load_export_rows_for_run] Missing batch job for run id=%s", - run.id, + "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id ) return [] - if assessment is None: - assessment = session.get(Assessment, run.assessment_id) - if assessment is None: + batch_job = get_batch_job(session=session, batch_job_id=run.batch_job_id) + if not batch_job: logger.warning( - "[load_export_rows_for_run] Parent assessment missing for run id=%s", - run.id, + "[load_export_rows_for_run] Missing batch job for run id=%s", run.id ) return [] parsed_results = _load_parsed_results_for_run( - session=session, - run=run, - batch_job=batch_job, + session=session, run=run, batch_job=batch_job ) - if parsed_results is None: - return [] - if not parsed_results: logger.warning( "[load_export_rows_for_run] Parsed results empty for run id=%s", run.id ) return [] - dataset_rows = _load_dataset_rows_for_run(session, run, assessment) - dataset = session.get(EvaluationDataset, assessment.dataset_id) - dataset_name = dataset.name if dataset else None - - export_rows: list[AssessmentExportRow] = [] + export_rows = [] for item in parsed_results: input_tokens, output_tokens, total_tokens = usage_totals(item.get("usage")) - - # Correlate with original input row via row_id (format: "row_{idx}") - input_data: dict[str, str] | None = None + input_data = None row_id_str = str(item.get("row_id", "")) if dataset_rows and row_id_str.startswith("row_"): try: From b4128290c424729997a409ea3977d3c28c71a66d Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Thu, 28 May 2026 21:33:42 +0530 Subject: [PATCH 2/8] feat(export): Expand output columns to include topic relevance and duplicate detection --- .../app/services/assessment/utils/export.py | 22 ++++++++----------- 1 file changed, 9 insertions(+), 13 deletions(-) diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index d244ecd08..4b299151f 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -154,7 +154,7 @@ def _parse_json_col(raw: Any) -> dict[str, Any] | None: def _expand_output_columns( row_payload: list[dict[str, Any]], -) -> tuple[list[dict[str, Any]], list[str]]: +) -> tuple[list[dict[str, Any]], list[str], list[str], list[str], list[str]]: """Expand ``output``, ``topic_relevance``, and ``duplicate_detection`` JSON columns into separate flat columns when they contain valid JSON objects. @@ -225,7 +225,7 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: if not all_output_keys: fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) fieldnames = [f for f in fieldnames if f != "input_data"] - return row_payload, fieldnames + return row_payload, fieldnames, input_col_names, [], [] fieldnames = input_col_names + l1_keys + output_keys + base_fields if has_unparsed_output: @@ -233,7 +233,7 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: len(input_col_names) + len(l1_keys) + len(output_keys), "output_raw" ) - return expanded, fieldnames + return expanded, fieldnames, input_col_names, l1_keys, output_keys def serialize_export_rows( @@ -244,13 +244,13 @@ def serialize_export_rows( row_payload = [row.model_dump(mode="json") for row in export_rows] if export_format == "json": - expanded, _ = _expand_output_columns(row_payload) + expanded, *_ = _expand_output_columns(row_payload) return ( json.dumps(expanded, ensure_ascii=False, indent=2).encode("utf-8"), "application/json", ) - expanded, fieldnames = _expand_output_columns(row_payload) + expanded, fieldnames, input_col_names, l1_keys, output_keys = _expand_output_columns(row_payload) if export_format == "csv": output = io.StringIO() @@ -267,14 +267,10 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - metadata_fields = { - field - for field in AssessmentExportRow.model_fields.keys() - if field not in ("output", "input_data") - } - excel_fields = [field for field in fieldnames if field not in metadata_fields] + # Explicit ordering: inputs → L1 topic relevance → L1 duplicate detection → L2 output + excel_fields = input_col_names + l1_keys + output_keys if not excel_fields: - excel_fields = ["output"] + excel_fields = output_keys or ["output"] # Drop columns where every row is null/empty expanded, excel_fields = _drop_empty_columns(expanded, excel_fields) @@ -294,7 +290,7 @@ def build_json_export_rows( ) -> list[dict[str, Any]]: """Return JSON rows with structured output expanded into top-level keys.""" row_payload = [row.model_dump(mode="json") for row in export_rows] - expanded, _ = _expand_output_columns(row_payload) + expanded, *_ = _expand_output_columns(row_payload) return expanded From c12ac18bd185cf0f327011050082d99dbc1bb386 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Sun, 31 May 2026 08:00:55 +0530 Subject: [PATCH 3/8] feat(post-processing): Implement post-processing configuration for assessment runs --- .../docs/assessment/update_post_processing.md | 15 ++ backend/app/api/routes/assessment/runs.py | 39 +++- backend/app/crud/assessment/__init__.py | 2 + backend/app/crud/assessment/core.py | 25 +++ backend/app/models/assessment.py | 10 +- backend/app/services/assessment/service.py | 3 + .../app/services/assessment/utils/export.py | 39 +++- .../assessment/utils/post_processing.py | 184 ++++++++++++++++++ 8 files changed, 306 insertions(+), 11 deletions(-) create mode 100644 backend/app/api/docs/assessment/update_post_processing.md create mode 100644 backend/app/services/assessment/utils/post_processing.py diff --git a/backend/app/api/docs/assessment/update_post_processing.md b/backend/app/api/docs/assessment/update_post_processing.md new file mode 100644 index 000000000..0d6f3278a --- /dev/null +++ b/backend/app/api/docs/assessment/update_post_processing.md @@ -0,0 +1,15 @@ +Save post-processing config for a single assessment run. + +Stores the config inside the run's `input` JSON blob (key +`post_processing_config`). It is applied at export/preview time and never +re-runs the LLM, so it can be edited after the run completes. + +The config has three optional sections: + +- `computed_columns`: derived columns from formulas, e.g. + `{"name": "Total_Score", "formula": "@Novelty_score + @Usefulness_score"}`. + Formulas reference columns with `@` and support `+ - * /` and parentheses. +- `filter`: row filters combined with AND logic. +- `sort`: sort rules applied in priority order. + +Pass `null` (or an empty body) to clear post-processing for the run. diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 18398eeb0..3c3abd57a 100644 --- a/backend/app/api/routes/assessment/runs.py +++ b/backend/app/api/routes/assessment/runs.py @@ -3,7 +3,7 @@ import logging from typing import Any, Literal -from fastapi import APIRouter, Depends, Query +from fastapi import APIRouter, Body, Depends, HTTPException, Query from fastapi.responses import StreamingResponse from app.api.deps import AuthContextDep, SessionDep @@ -12,6 +12,7 @@ get_assessment_by_id, get_assessment_run_by_id as get_run_by_id, list_assessment_runs as list_runs, + update_run_post_processing_config, ) from app.models.assessment import ( Assessment, @@ -33,6 +34,7 @@ load_export_rows_for_run, sort_export_rows, ) +from app.services.assessment.utils.post_processing import apply_post_processing from app.utils import APIResponse, load_description logger = logging.getLogger(__name__) @@ -68,6 +70,7 @@ def _build_run_public( l1_total_rows=run.l1_total_rows, l1_total_passed=run.l1_total_passed, l1_total_rejected=run.l1_total_rejected, + post_processing_config=(run.input or {}).get("post_processing_config"), inserted_at=run.inserted_at, updated_at=run.updated_at, ) @@ -215,12 +218,44 @@ def export_assessment_run_results( ) ) + post_processing_config = (run.input or {}).get("post_processing_config") or None base_label = assessment.experiment_name if assessment else f"run_{run.id}" + if export_format != "json": return build_export_response( export_rows=export_rows, export_format=export_format, base_name=f"{base_label}_run_{run.id}_results", + post_processing_config=post_processing_config, ) - return APIResponse.success_response(data=build_json_export_rows(export_rows)) + rows = build_json_export_rows(export_rows) + rows = apply_post_processing(rows, post_processing_config) + return APIResponse.success_response(data=rows) + + +@router.patch( + "/runs/{run_id}/post-processing", + description=load_description("assessment/update_post_processing.md"), + response_model=APIResponse[AssessmentRunPublic], + dependencies=[Depends(require_permission(Permission.REQUIRE_PROJECT))], +) +def update_post_processing( + run_id: int, + session: SessionDep, + auth_context: AuthContextDep, + config: dict[str, Any] | None = Body(default=None), +) -> APIResponse[AssessmentRunPublic]: + """Save post-processing config (computed columns, sort, filter) for a run.""" + run = get_run_by_id( + session=session, + run_id=run_id, + organization_id=auth_context.organization_.id, + project_id=auth_context.project_.id, + ) + if run is None: + raise HTTPException(status_code=404, detail="Run not found") + + run = update_run_post_processing_config(session=session, run=run, config=config) + + return APIResponse.success_response(data=_build_run_public(session, run)) diff --git a/backend/app/crud/assessment/__init__.py b/backend/app/crud/assessment/__init__.py index 8e5c9984d..8e623e3a7 100644 --- a/backend/app/crud/assessment/__init__.py +++ b/backend/app/crud/assessment/__init__.py @@ -15,6 +15,7 @@ recompute_assessment_status, update_assessment_run_l1_stats, update_assessment_run_status, + update_run_post_processing_config, ) from app.crud.assessment.dataset import ( create_assessment_dataset, @@ -45,4 +46,5 @@ "recompute_assessment_status", "update_assessment_run_l1_stats", "update_assessment_run_status", + "update_run_post_processing_config", ] diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index 6e2c8b2f7..d5a184d06 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -5,6 +5,7 @@ from uuid import UUID from fastapi import HTTPException +from sqlalchemy.orm.attributes import flag_modified from sqlmodel import Session, select from app.core.util import now @@ -129,6 +130,30 @@ def create_assessment_run( return run +def update_run_post_processing_config( + session: Session, + run: AssessmentRun, + config: dict[str, Any] | None, +) -> AssessmentRun: + """Set post_processing_config inside the run's input JSON blob and persist.""" + run.input = {**(run.input or {}), "post_processing_config": config} + flag_modified(run, "input") + session.add(run) + try: + session.commit() + session.refresh(run) + except Exception as e: + session.rollback() + logger.error( + f"[update_run_post_processing_config] Failed for run id={run.id}: {e}", + exc_info=True, + ) + raise + + logger.info(f"[update_run_post_processing_config] Updated run id={run.id}") + return run + + def get_assessment_run_by_id( session: Session, run_id: int, diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index 0dd0a96d1..b8620af06 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -5,7 +5,7 @@ from uuid import UUID from pydantic import BaseModel, Field -from sqlalchemy import Column, Index, Text +from sqlalchemy import JSON, Column, Index, Text from sqlalchemy.dialects.postgresql import JSONB from sqlmodel import Field as SQLField from sqlmodel import Relationship, SQLModel @@ -253,6 +253,7 @@ class AssessmentRunPublic(BaseModel): l1_total_rows: int | None = None l1_total_passed: int | None = None l1_total_rejected: int | None = None + post_processing_config: dict[str, Any] | None = None inserted_at: datetime updated_at: datetime @@ -322,6 +323,13 @@ class AssessmentCreate(BaseModel): "duplicate_detection (columns). Omit to skip L1." ), ) + post_processing_config: dict[str, Any] | None = Field( + None, + description=( + "Post-processing config applied at export. " + "Keys: computed_columns, sort, filter." + ), + ) class AssessmentRunSummary(BaseModel): diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index d03eb13a7..cabe2bb4c 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -80,6 +80,7 @@ def _build_retry_request( output_schema=assessment_input.get("output_schema"), configs=configs, l1_config=assessment_input.get("l1_config"), + post_processing_config=assessment_input.get("post_processing_config"), ) @@ -121,6 +122,8 @@ def start_assessment( assessment_input["output_schema"] = request.output_schema if request.l1_config: assessment_input["l1_config"] = request.l1_config + if request.post_processing_config: + assessment_input["post_processing_config"] = request.post_processing_config config_crud = ConfigCrud(session=session, project_id=project_id) diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index 4b299151f..86d9186b0 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -239,22 +239,43 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: def serialize_export_rows( export_rows: list[AssessmentExportRow], export_format: Literal["json", "csv", "xlsx"], + post_processing_config: dict[str, Any] | None = None, ) -> tuple[bytes, str]: """Serialize export rows into the requested file format.""" + from app.services.assessment.utils.post_processing import apply_post_processing + row_payload = [row.model_dump(mode="json") for row in export_rows] if export_format == "json": expanded, *_ = _expand_output_columns(row_payload) + expanded = apply_post_processing(expanded, post_processing_config) return ( json.dumps(expanded, ensure_ascii=False, indent=2).encode("utf-8"), "application/json", ) - expanded, fieldnames, input_col_names, l1_keys, output_keys = _expand_output_columns(row_payload) + ( + expanded, + fieldnames, + input_col_names, + l1_keys, + output_keys, + ) = _expand_output_columns(row_payload) + expanded = apply_post_processing(expanded, post_processing_config) + + # Add any new computed columns to fieldnames so they appear in output + existing = set(fieldnames) + computed_names = [ + c["name"] + for c in (post_processing_config or {}).get("computed_columns") or [] + if c.get("name") and c["name"] not in existing + ] + if computed_names: + fieldnames = fieldnames + computed_names if export_format == "csv": output = io.StringIO() - writer = csv.DictWriter(output, fieldnames=fieldnames) + writer = csv.DictWriter(output, fieldnames=fieldnames, extrasaction="ignore") writer.writeheader() writer.writerows(expanded) return output.getvalue().encode("utf-8"), "text/csv" @@ -267,12 +288,11 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - # Explicit ordering: inputs → L1 topic relevance → L1 duplicate detection → L2 output - excel_fields = input_col_names + l1_keys + output_keys + # Explicit ordering: inputs → L1 → L2 → computed columns + excel_fields = input_col_names + l1_keys + output_keys + computed_names if not excel_fields: excel_fields = output_keys or ["output"] - # Drop columns where every row is null/empty expanded, excel_fields = _drop_empty_columns(expanded, excel_fields) buf = io.BytesIO() @@ -290,17 +310,20 @@ def build_json_export_rows( ) -> list[dict[str, Any]]: """Return JSON rows with structured output expanded into top-level keys.""" row_payload = [row.model_dump(mode="json") for row in export_rows] - expanded, *_ = _expand_output_columns(row_payload) - return expanded + expanded, fieldnames, *_ = _expand_output_columns(row_payload) + return [{k: row.get(k) for k in fieldnames if k in row} for row in expanded] def build_export_response( export_rows: list[AssessmentExportRow], export_format: Literal["json", "csv", "xlsx"], base_name: str, + post_processing_config: dict[str, Any] | None = None, ) -> StreamingResponse: """Return a file download response for assessment exports.""" - payload, media_type = serialize_export_rows(export_rows, export_format) + payload, media_type = serialize_export_rows( + export_rows, export_format, post_processing_config + ) filename = generate_timestamped_filename( _safe_filename_part(base_name), extension=export_format, diff --git a/backend/app/services/assessment/utils/post_processing.py b/backend/app/services/assessment/utils/post_processing.py new file mode 100644 index 000000000..9b0d36c45 --- /dev/null +++ b/backend/app/services/assessment/utils/post_processing.py @@ -0,0 +1,184 @@ +"""Post-processing engine for assessment exports. +""" + +import ast +import logging +import operator +import re +from typing import Any + +logger = logging.getLogger(__name__) + +# Safe formula evaluator +_SAFE_OPS = { + ast.Add: operator.add, + ast.Sub: operator.sub, + ast.Mult: operator.mul, + ast.Div: operator.truediv, + ast.USub: operator.neg, +} + + +def _eval_node(node: ast.AST) -> float: + if isinstance(node, ast.Constant) and isinstance(node.value, (int, float)): + return float(node.value) + if isinstance(node, ast.BinOp) and type(node.op) in _SAFE_OPS: + return _SAFE_OPS[type(node.op)](_eval_node(node.left), _eval_node(node.right)) + if isinstance(node, ast.UnaryOp) and type(node.op) in _SAFE_OPS: + return _SAFE_OPS[type(node.op)](_eval_node(node.operand)) + raise ValueError(f"Unsupported operation in formula: {ast.dump(node)}") + + +def evaluate_formula(formula: str, row: dict[str, Any]) -> float | None: + """Evaluate a formula like '@Novelty_score + @Feasibility_score * 0.5'. + + Returns None if the formula fails or references missing columns. + """ + + def resolve(match: re.Match) -> str: + col = match.group(1) + val = row.get(col) + if val is None: + return "0" + try: + return str(float(val)) + except (TypeError, ValueError): + return "0" + + expr = re.sub(r"@([\w]+)", resolve, formula) + + try: + tree = ast.parse(expr, mode="eval") + return _eval_node(tree.body) + except Exception as exc: + logger.warning("[evaluate_formula] Failed to evaluate %r: %s", formula, exc) + return None + + +# Filter + +_FILTER_OPS = { + "eq": lambda a, b: str(a).strip().lower() == str(b).strip().lower(), + "ne": lambda a, b: str(a).strip().lower() != str(b).strip().lower(), + "contains": lambda a, b: str(b).lower() in str(a).lower(), + "not_contains": lambda a, b: str(b).lower() not in str(a).lower(), + "in": lambda a, b: str(a).strip().lower() in {str(v).lower() for v in b}, + "not_in": lambda a, b: str(a).strip().lower() not in {str(v).lower() for v in b}, + "is_empty": lambda a, _: a is None or str(a).strip() == "", + "is_not_empty": lambda a, _: a is not None and str(a).strip() != "", +} + + +def _numeric_filter(op: str, a: Any, b: Any) -> bool: + try: + fa, fb = float(a), float(b) + if op == "gt": + return fa > fb + if op == "lt": + return fa < fb + if op == "gte": + return fa >= fb + if op == "lte": + return fa <= fb + except (TypeError, ValueError): + pass + return False + + +def _row_matches_filter(row: dict[str, Any], rule: dict[str, Any]) -> bool: + col = rule["column"] + op = rule["op"] + value = rule.get("value") + cell = row.get(col) + + if op in ("gt", "lt", "gte", "lte"): + return _numeric_filter(op, cell, value) + if op in _FILTER_OPS: + return _FILTER_OPS[op](cell, value) + return True + + +def apply_computed_columns( + rows: list[dict[str, Any]], + computed_columns: list[dict[str, Any]], +) -> None: + """Add computed columns to each row in-place.""" + for row in rows: + for col_def in computed_columns: + name = col_def.get("name", "").strip() + formula = col_def.get("formula", "").strip() + if not name or not formula: + continue + row[name] = evaluate_formula(formula, row) + + +def apply_filter( + rows: list[dict[str, Any]], + filter_rules: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Return only rows that match ALL filter rules (AND logic).""" + if not filter_rules: + return rows + return [ + row + for row in rows + if all(_row_matches_filter(row, rule) for rule in filter_rules) + ] + + +def apply_sort( + rows: list[dict[str, Any]], + sort_rules: list[dict[str, Any]], +) -> list[dict[str, Any]]: + """Sort rows by priority-ordered rules. First rule has highest priority.""" + if not sort_rules: + return rows + + # Build sort key: iterate rules in reverse (lowest priority first) + # so that highest priority rule is the final (dominant) tiebreaker. + result = rows + for rule in reversed(sort_rules): + col = rule.get("column", "") + desc = str(rule.get("direction", "asc")).lower() == "desc" + + def sort_key(row: dict[str, Any], _col: str = col) -> tuple: + val = row.get(_col) + if val is None: + return (1, 0, "") + try: + return (0, -float(val) if desc else float(val), "") + except (TypeError, ValueError): + s = str(val).lower() + return ( + (0, 0, s) + if not desc + else (0, 0, "".join(chr(0x10FFFF - ord(c)) for c in s)) + ) + + result = sorted(result, key=sort_key) + + return result + + +def apply_post_processing( + rows: list[dict[str, Any]], + config: dict[str, Any] | None, +) -> list[dict[str, Any]]: + """Apply full post-processing pipeline: computed columns → filter → sort. + + Safe to call with config=None (no-op). + """ + if not config: + return rows + + computed_columns = config.get("computed_columns") or [] + filter_rules = config.get("filter") or [] + sort_rules = config.get("sort") or [] + + if computed_columns: + apply_computed_columns(rows, computed_columns) + + rows = apply_filter(rows, filter_rules) + rows = apply_sort(rows, sort_rules) + + return rows From c1791d5f44e7f6be10827b6d3dcbdfd3e5b28565 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 09:30:09 +0530 Subject: [PATCH 4/8] feat(assessment): Enhance attachment handling in L1 pipeline with mixed type detection and improved utility functions --- backend/app/crud/assessment/batch.py | 71 +------ backend/app/models/assessment.py | 9 +- .../app/services/assessment/l1/pipeline.py | 15 +- .../services/assessment/l1/topic_relevance.py | 34 ++- backend/app/services/assessment/tasks.py | 10 +- .../services/assessment/utils/attachments.py | 193 +++++++++++++++++- backend/app/tests/assessment/test_batch.py | 104 ++++++++++ backend/app/tests/assessment/test_export.py | 12 +- .../tests/assessment/test_topic_relevance.py | 123 +++++++++++ 9 files changed, 487 insertions(+), 84 deletions(-) create mode 100644 backend/app/tests/assessment/test_topic_relevance.py diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index e5e52daec..531dc038d 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -30,11 +30,8 @@ normalize_llm_text, ) from app.services.assessment.utils.attachments import ( + build_gemini_attachment_parts, resolve_attachment_values, - resolve_image_mime_and_payload, - split_attachment_urls, - split_data_url, - to_direct_attachment_url, ) from app.services.llm.providers.registry import LLMProvider @@ -174,6 +171,8 @@ def build_openai_jsonl( } """ jsonl_data = [] + # Memoize per-item type probes across all rows in this build. + type_cache: dict[str, str] = {} for i, row in enumerate(rows): idx = row_indices[i] if row_indices is not None else i @@ -188,7 +187,7 @@ def build_openai_jsonl( # Attachments for att in attachments: cell_value = row.get(att.column, "") - input_parts.extend(resolve_attachment_values(cell_value, att)) + input_parts.extend(resolve_attachment_values(cell_value, att, type_cache)) if not input_parts: logger.warning("[build_openai_jsonl] Skipping empty row | idx=%s", idx) @@ -232,6 +231,8 @@ def build_google_jsonl( } """ jsonl_data = [] + # Memoize per-item type probes across all rows in this build. + type_cache: dict[str, str] = {} for i, row in enumerate(rows): idx = row_indices[i] if row_indices is not None else i @@ -244,64 +245,8 @@ def build_google_jsonl( # Attachments (Gemini uses file_data for inline content) for att in attachments: - cell_value = row.get(att.column, "").strip() - if not cell_value: - continue - - cell_values = ( - split_attachment_urls(cell_value) - if att.format == "url" - else [cell_value] - ) - - for item_value in cell_values: - normalized_value = ( - to_direct_attachment_url(item_value, att.type) - if att.format == "url" - else item_value - ) - if att.type == "image": - mime_type, payload = resolve_image_mime_and_payload( - normalized_value, - att.format, - ) - if att.format == "url": - parts.append( - { - "fileData": { - "mimeType": mime_type, - "fileUri": normalized_value, - } - } - ) - else: - parts.append( - { - "inlineData": { - "mimeType": mime_type, - "data": payload, - } - } - ) - elif att.type == "pdf": - if att.format == "url": - parts.append( - { - "fileData": { - "mimeType": "application/pdf", - "fileUri": normalized_value, - } - } - ) - else: - parts.append( - { - "inlineData": { - "mimeType": "application/pdf", - "data": split_data_url(normalized_value)[1], - } - } - ) + cell_value = row.get(att.column, "") + parts.extend(build_gemini_attachment_parts(cell_value, att, type_cache)) if not parts: logger.warning("[build_google_jsonl] Skipping empty row | idx=%s", idx) diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index b8620af06..b5a1a31f5 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -275,7 +275,14 @@ class AssessmentAttachment(BaseModel): """Attachment column configuration.""" column: str = Field(..., description="Column name containing the attachment data") - type: Literal["image", "pdf"] = Field(..., description="Attachment type") + type: Literal["image", "pdf", "mixed"] = Field( + ..., + description=( + "Attachment type. 'mixed' detects image vs pdf per item (for columns " + "that contain both); 'image'/'pdf' force a type and act as fallback " + "when per-item detection is inconclusive." + ), + ) format: Literal["url", "base64"] = Field(..., description="Data format") diff --git a/backend/app/services/assessment/l1/pipeline.py b/backend/app/services/assessment/l1/pipeline.py index 2a002e5e5..18df91324 100644 --- a/backend/app/services/assessment/l1/pipeline.py +++ b/backend/app/services/assessment/l1/pipeline.py @@ -17,7 +17,7 @@ from app.core.config import settings from app.core.cloud import get_cloud_storage from app.core.storage_utils import upload_jsonl_to_object_store -from app.models.assessment import AssessmentRun +from app.models.assessment import AssessmentAttachment, AssessmentRun from app.services.assessment.l1.duplicate_detection import run_duplicate_detection from app.services.assessment.l1.topic_relevance import run_topic_relevance @@ -50,6 +50,7 @@ def run_l1_pipeline( session: Session, organization_id: int, project_id: int, + attachments: list[AssessmentAttachment] | None = None, ) -> tuple[list[dict[str, str]], list[int], list[dict[str, Any]]]: """Run L1 filters on all rows. @@ -78,6 +79,13 @@ def run_l1_pipeline( tr_prompt: str = tr_config.get("prompt") or "" dup_columns: list[str] = dup_config.get("columns") or [] + tr_attachment_columns = tr_config.get("attachment_columns") + if tr_attachment_columns is None: + tr_attachments = list(attachments or []) + else: + selected = set(tr_attachment_columns) + tr_attachments = [a for a in (attachments or []) if a.column in selected] + tr_enabled = bool(tr_columns and tr_prompt) dup_enabled = bool(dup_columns) @@ -105,6 +113,9 @@ def run_l1_pipeline( ) # tr_results[idx] = None when TR disabled → no topic_relevance columns in export + # Shared across rows so each unique attachment file is type-probed once. + attachment_type_cache: dict[str, str] = {} + tr_results: dict[int, dict[str, Any] | None] = {} if tr_enabled: with ThreadPoolExecutor(max_workers=workers) as executor: @@ -117,6 +128,8 @@ def run_l1_pipeline( tr_prompt, gemini_client, model, + tr_attachments, + attachment_type_cache, ): idx for idx, row in enumerate(rows) } diff --git a/backend/app/services/assessment/l1/topic_relevance.py b/backend/app/services/assessment/l1/topic_relevance.py index 42516d27b..c1894c04e 100644 --- a/backend/app/services/assessment/l1/topic_relevance.py +++ b/backend/app/services/assessment/l1/topic_relevance.py @@ -8,6 +8,9 @@ from google import genai from google.genai import types +from app.models.assessment import AssessmentAttachment +from app.services.assessment.utils.attachments import build_gemini_attachment_parts + logger = logging.getLogger(__name__) @@ -45,21 +48,42 @@ def run_topic_relevance( user_prompt: str, gemini_client: genai.Client, model: str, + attachments: list[AssessmentAttachment] | None = None, + type_cache: dict[str, str] | None = None, ) -> dict[str, Any]: """Run topic relevance check on a single row. System instruction = user_prompt (the evaluation rubric/criteria). - User content = dict of {column_name: value} for the selected columns. + User content = the selected columns as JSON plus every mapped attachment + (image/pdf) for the row, so relevance is judged on text and documents. + Each attachment column also gets its own relevance boolean in the schema, + so the export carries a ``topic_relevance_`` column. Output schema enforced: decision (ACCEPT/REJECT) + reasoning. On error defaults to verdict=True (fail-open). """ + # Document columns that actually have a value for this row. + doc_columns: list[str] = [] + for att in attachments or []: + if att.column not in doc_columns and (row.get(att.column) or "").strip(): + doc_columns.append(att.column) + + schema_columns = columns + doc_columns user_content = json.dumps({col: row.get(col, "") or "" for col in columns}) - output_schema = _build_output_schema(columns) + output_schema = _build_output_schema(schema_columns) + + parts: list[dict[str, Any]] = [{"text": user_content}] + for att in attachments or []: + attachment_parts = build_gemini_attachment_parts( + row.get(att.column, ""), att, type_cache + ) + if attachment_parts: + parts.append({"text": f"Attached document(s) for column '{att.column}':"}) + parts.extend(attachment_parts) try: response = gemini_client.models.generate_content( model=model, - contents=user_content, + contents=[{"role": "user", "parts": parts}], config=types.GenerateContentConfig( system_instruction=user_prompt.strip(), response_mime_type="application/json", @@ -70,7 +94,7 @@ def run_topic_relevance( raw = (response.text or "").strip() parsed = json.loads(raw) decision = str(parsed.get("decision", "ACCEPT")).upper() - column_relevance = {col: bool(parsed.get(col, True)) for col in columns} + column_relevance = {col: bool(parsed.get(col, True)) for col in schema_columns} return { "row_id": f"row_{row_idx}", "verdict": decision == "ACCEPT", @@ -88,6 +112,6 @@ def run_topic_relevance( "row_id": f"row_{row_idx}", "verdict": True, "decision": "ACCEPT", - "column_relevance": {col: True for col in columns}, + "column_relevance": {col: True for col in schema_columns}, "reasoning": f"(evaluation error — defaulting to pass) {exc}", } diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py index 66f644050..295c55ad2 100644 --- a/backend/app/services/assessment/tasks.py +++ b/backend/app/services/assessment/tasks.py @@ -13,7 +13,11 @@ from app.crud.assessment.batch import _load_dataset_rows, submit_assessment_batch from app.crud.config import ConfigCrud from app.crud.evaluations.core import resolve_evaluation_config -from app.models.assessment import Assessment, AssessmentRun +from app.models.assessment import ( + Assessment, + AssessmentAttachment, + AssessmentRun, +) from app.models.config.config import ConfigTag from app.services.assessment.l1 import run_l1_pipeline @@ -128,6 +132,10 @@ def execute_assessment_run( session=session, organization_id=organization_id, project_id=project_id, + attachments=[ + AssessmentAttachment(**a) + for a in assessment_input.get("attachments") or [] + ], ) logger.info( "[execute_assessment_run] L1 done | run_id=%s | rows_to_l2=%s / %s", diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 5a141a757..3622f9bce 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -6,12 +6,17 @@ import base64 import binascii +import logging import re from typing import Any from urllib.parse import urlparse +import requests + from app.models.assessment import AssessmentAttachment +logger = logging.getLogger(__name__) + _IMAGE_MIME_BY_EXT = { ".png": "image/png", ".jpg": "image/jpeg", @@ -92,10 +97,8 @@ def _decode_base64_prefix(payload: str, max_chars: int = 256) -> bytes | None: return None -def _guess_image_mime_from_base64(payload: str) -> str | None: - blob = _decode_base64_prefix(payload) - if not blob: - return None +def _image_mime_from_magic(blob: bytes) -> str | None: + """Detect image mime type from leading magic bytes.""" if blob.startswith(b"\x89PNG\r\n\x1a\n"): return "image/png" if blob.startswith(b"\xff\xd8\xff"): @@ -111,6 +114,22 @@ def _guess_image_mime_from_base64(payload: str) -> str | None: return None +def _guess_image_mime_from_base64(payload: str) -> str | None: + blob = _decode_base64_prefix(payload) + if not blob: + return None + return _image_mime_from_magic(blob) + + +def _type_from_magic(blob: bytes) -> str | None: + """Detect 'image' or 'pdf' from leading magic bytes; None if neither.""" + if blob.startswith(b"%PDF"): + return "pdf" + if _image_mime_from_magic(blob): + return "image" + return None + + def resolve_image_mime_and_payload( value: str, format_type: str, @@ -126,9 +145,110 @@ def resolve_image_mime_and_payload( return _guess_image_mime_from_base64(payload) or "image/png", payload +def _drive_file_id(url: str) -> str | None: + """Extract a Google Drive file id from common share URL shapes.""" + match = re.match(r"https://drive\.google\.com/file/d/([^/]+)", url) + if match: + return match.group(1) + match = re.search(r"[?&]id=([a-zA-Z0-9_-]+)", url) + if match and ("drive.google.com" in url or "drive.usercontent.google.com" in url): + return match.group(1) + return None + + +def _type_from_url_extension(url: str) -> str | None: + """Detect 'image' or 'pdf' from a URL path extension; None if unknown.""" + path = (urlparse(url).path or "").lower() + if path.endswith(".pdf"): + return "pdf" + if _guess_image_mime_from_url(url): + return "image" + return None + + +def _type_from_content_type(content_type: str | None) -> str | None: + if not content_type: + return None + content_type = content_type.split(";")[0].strip().lower() + if content_type == "application/pdf": + return "pdf" + if content_type.startswith("image/"): + return "image" + return None + + +def _probe_url_type(url: str, num_bytes: int = 16) -> str | None: + """Probe a remote URL's type: ranged byte sniff first, Content-Type fallback. + + Reads only the first few bytes (does not download the whole file). Drive + share URLs are routed through the download endpoint so the real file bytes + are read instead of an HTML share page. + """ + file_id = _drive_file_id(url) + probe_url = ( + f"https://drive.google.com/uc?export=download&id={file_id}" if file_id else url + ) + + try: + with requests.get( + probe_url, + headers={"Range": f"bytes=0-{num_bytes - 1}"}, + timeout=10, + stream=True, + allow_redirects=True, + ) as resp: + resp.raise_for_status() + for chunk in resp.iter_content(chunk_size=num_bytes): + magic_type = _type_from_magic(chunk) + if magic_type: + return magic_type + break + return _type_from_content_type(resp.headers.get("Content-Type")) + except requests.RequestException as e: + logger.warning(f"[_probe_url_type] Probe failed for {url}: {e}") + return None + + +def detect_item_type( + value: str, + format_type: str, + fallback: str, + cache: dict[str, str] | None = None, +) -> str: + """Resolve a single attachment item as 'image' or 'pdf'. + + Order: data-URL/base64 magic (no network) -> URL extension -> remote probe + (ranged byte sniff, then Content-Type) -> declared ``fallback`` type. + ``fallback`` may be 'mixed'; when detection is inconclusive it resolves to + 'image'. Remote probe results are memoized in ``cache`` keyed by item value. + """ + # 'mixed' is not a concrete output type; terminal default is image. + safe_fallback = fallback if fallback in ("image", "pdf") else "image" + + if format_type != "url": + data_url_mime, payload = split_data_url(value) + if data_url_mime == "application/pdf": + return "pdf" + if data_url_mime and data_url_mime.startswith("image/"): + return "image" + blob = _decode_base64_prefix(payload) + return (_type_from_magic(blob) if blob else None) or safe_fallback + + if cache is not None and value in cache: + return cache[value] + + item_type = ( + _type_from_url_extension(value) or _probe_url_type(value) or safe_fallback + ) + if cache is not None: + cache[value] = item_type + return item_type + + def resolve_attachment_values( value: str, att: AssessmentAttachment, + type_cache: dict[str, str] | None = None, ) -> list[dict[str, Any]]: """Convert one dataset cell into one or more OpenAI-style input objects.""" value = value.strip() @@ -142,13 +262,14 @@ def resolve_attachment_values( resolved: list[dict[str, Any]] = [] for item_value in values: + item_type = detect_item_type(item_value, att.format, att.type, type_cache) normalized_value = ( - to_direct_attachment_url(item_value, att.type) + to_direct_attachment_url(item_value, item_type) if att.format == "url" else item_value ) - if att.type == "image": + if item_type == "image": if att.format == "url": resolved.append({"type": "input_image", "image_url": normalized_value}) else: @@ -162,7 +283,7 @@ def resolve_attachment_values( "image_url": f"data:{mime_type};base64,{payload}", } ) - elif att.type == "pdf": + elif item_type == "pdf": if att.format == "url": resolved.append( { @@ -181,3 +302,61 @@ def resolve_attachment_values( ) return resolved + + +def build_gemini_attachment_parts( + value: str, + att: AssessmentAttachment, + type_cache: dict[str, str] | None = None, +) -> list[dict[str, Any]]: + """Convert one dataset cell into one or more Gemini content parts. + + Mirrors the per-item type detection used for the L2 batch so the same + image/pdf routing applies to L1 (topic relevance) calls. + """ + value = value.strip() + if not value: + return [] + + values = split_attachment_urls(value) if att.format == "url" else [value] + + parts: list[dict[str, Any]] = [] + for item_value in values: + item_type = detect_item_type(item_value, att.format, att.type, type_cache) + normalized_value = ( + to_direct_attachment_url(item_value, item_type) + if att.format == "url" + else item_value + ) + + if item_type == "image": + mime_type, payload = resolve_image_mime_and_payload( + normalized_value, att.format + ) + if att.format == "url": + parts.append( + {"fileData": {"mimeType": mime_type, "fileUri": normalized_value}} + ) + else: + parts.append({"inlineData": {"mimeType": mime_type, "data": payload}}) + elif item_type == "pdf": + if att.format == "url": + parts.append( + { + "fileData": { + "mimeType": "application/pdf", + "fileUri": normalized_value, + } + } + ) + else: + parts.append( + { + "inlineData": { + "mimeType": "application/pdf", + "data": split_data_url(normalized_value)[1], + } + } + ) + + return parts diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index 6d524e81f..41d84198d 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -1,5 +1,6 @@ """Tests for assessment/batch.py provider routing in submit_assessment_batch.""" +import base64 import io from types import SimpleNamespace from unittest.mock import MagicMock, patch @@ -21,6 +22,7 @@ _decode_base64_prefix, _guess_image_mime_from_base64, _guess_image_mime_from_url, + detect_item_type, resolve_attachment_values, resolve_image_mime_and_payload, split_attachment_urls, @@ -423,3 +425,105 @@ def test_build_openai_and_google_jsonl(self) -> None: assert google_jsonl[0]["request"]["systemInstruction"] == { "parts": [{"text": "system"}] } + + +class TestDetectItemType: + """Per-item image/pdf detection for mixed-content attachment columns.""" + + def test_data_url_pdf(self) -> None: + assert ( + detect_item_type("data:application/pdf;base64,JVBERi0=", "base64", "image") + == "pdf" + ) + + def test_data_url_image(self) -> None: + assert ( + detect_item_type("data:image/png;base64,AAAA", "base64", "pdf") == "image" + ) + + def test_base64_magic_pdf(self) -> None: + payload = base64.b64encode(b"%PDF-1.7 body").decode() + assert detect_item_type(payload, "base64", "image") == "pdf" + + def test_base64_magic_png(self) -> None: + payload = base64.b64encode(b"\x89PNG\r\n\x1a\n" + b"0" * 8).decode() + assert detect_item_type(payload, "base64", "pdf") == "image" + + def test_base64_unknown_falls_back(self) -> None: + payload = base64.b64encode(b"not a known magic").decode() + assert detect_item_type(payload, "base64", "pdf") == "pdf" + + def test_mixed_fallback_resolves_to_image(self) -> None: + """'mixed' is never a returned type; inconclusive detection -> image.""" + payload = base64.b64encode(b"not a known magic").decode() + assert detect_item_type(payload, "base64", "mixed") == "image" + + def test_url_extension_pdf_case_insensitive(self) -> None: + assert detect_item_type("https://x.com/a/scan.PDF", "url", "image", {}) == "pdf" + + def test_url_extension_image(self) -> None: + assert detect_item_type("https://x.com/a/p.jpg", "url", "pdf", {}) == "image" + + def test_url_no_extension_probes_bytes(self) -> None: + """Extensionless URL (Drive-style) is probed; magic bytes win over fallback.""" + url = "https://drive.google.com/file/d/ABC123/view" + resp = MagicMock() + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + resp.raise_for_status = MagicMock() + resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) + with patch( + "app.services.assessment.utils.attachments.requests.get", + return_value=resp, + ) as mock_get: + assert detect_item_type(url, "url", "image", {}) == "pdf" + # Drive share URL is probed through the download endpoint. + assert "uc?export=download&id=ABC123" in mock_get.call_args.args[0] + + def test_url_probe_uses_content_type_when_no_magic(self) -> None: + url = "https://example.com/file" + resp = MagicMock() + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + resp.raise_for_status = MagicMock() + resp.iter_content = MagicMock(return_value=iter([b"\x00\x01\x02\x03"])) + resp.headers = {"Content-Type": "application/pdf; charset=binary"} + with patch( + "app.services.assessment.utils.attachments.requests.get", + return_value=resp, + ): + assert detect_item_type(url, "url", "image", {}) == "pdf" + + def test_url_probe_failure_falls_back(self) -> None: + import requests as _requests + + url = "https://example.com/file" + with patch( + "app.services.assessment.utils.attachments.requests.get", + side_effect=_requests.RequestException("boom"), + ): + assert detect_item_type(url, "url", "image", {}) == "image" + + def test_cache_skips_second_probe(self) -> None: + url = "https://drive.google.com/file/d/XYZ/view" + cache: dict[str, str] = {} + resp = MagicMock() + resp.__enter__ = MagicMock(return_value=resp) + resp.__exit__ = MagicMock(return_value=False) + resp.raise_for_status = MagicMock() + resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) + with patch( + "app.services.assessment.utils.attachments.requests.get", + return_value=resp, + ) as mock_get: + assert detect_item_type(url, "url", "image", cache) == "pdf" + assert detect_item_type(url, "url", "image", cache) == "pdf" + assert mock_get.call_count == 1 + + def test_mixed_column_resolves_both_types(self) -> None: + """One column, two URLs with extensions -> one image, one pdf object.""" + att = AssessmentAttachment(column="docs", type="image", format="url") + value = "https://x.com/a/photo.jpg, https://x.com/b/report.pdf" + resolved = resolve_attachment_values(value, att, {}) + types = [obj["type"] for obj in resolved] + assert types == ["input_image", "input_file"] diff --git a/backend/app/tests/assessment/test_export.py b/backend/app/tests/assessment/test_export.py index 3ace89dbd..98eb10683 100644 --- a/backend/app/tests/assessment/test_export.py +++ b/backend/app/tests/assessment/test_export.py @@ -144,14 +144,14 @@ def test_all_empty_drops_all(self) -> None: class TestExpandOutputColumns: def test_plain_string_output_not_expanded(self) -> None: rows = [{"output": "plain text", "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "output" in fieldnames def test_json_dict_output_expanded(self) -> None: rows = [ {"output": json.dumps({"score": 5, "reason": "good"}), "input_data": None} ] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "score" in fieldnames assert "reason" in fieldnames assert expanded[0]["score"] == 5 @@ -161,14 +161,14 @@ def test_mixed_parsed_and_unparsed_adds_output_raw(self) -> None: {"output": json.dumps({"score": 3}), "input_data": None}, {"output": "not json", "input_data": None}, ] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "output_raw" in fieldnames # Second row that didn't parse should get output_raw assert expanded[1].get("output_raw") == "not json" def test_none_output_handled(self) -> None: rows = [{"output": None, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert expanded[0].get("output") is None @@ -253,13 +253,13 @@ class TestExpandOutputColumnsDictOutput: def test_dict_output_expanded_directly(self) -> None: # raw output is already a dict (not a JSON string) rows = [{"output": {"score": 9, "label": "good"}, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) assert "score" in fieldnames assert expanded[0]["score"] == 9 def test_non_dict_non_string_output_treated_as_unparsed(self) -> None: rows = [{"output": 42, "input_data": None}] - expanded, fieldnames = _expand_output_columns(rows) + expanded, fieldnames, *_ = _expand_output_columns(rows) # 42 is not a dict/string, treated as unparsed → output stays as-is assert "output" in fieldnames diff --git a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py new file mode 100644 index 000000000..ad52c2306 --- /dev/null +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -0,0 +1,123 @@ +"""Tests for L1 topic relevance attachment handling.""" + +import json +from unittest.mock import MagicMock + +from app.models.assessment import AssessmentAttachment +from app.services.assessment.l1.topic_relevance import run_topic_relevance + + +def _client_returning(decision: str) -> MagicMock: + client = MagicMock() + response = MagicMock() + response.text = json.dumps( + {"decision": decision, "Problem": True, "reasoning": "ok"} + ) + client.models.generate_content.return_value = response + return client + + +class TestTopicRelevanceAttachments: + def test_attachments_added_to_contents(self) -> None: + client = _client_returning("ACCEPT") + att = AssessmentAttachment(column="Documents", type="image", format="url") + row = {"Problem": "p", "Documents": "https://x.com/a/photo.jpg"} + + result = run_topic_relevance( + row_idx=0, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + attachments=[att], + type_cache={}, + ) + + assert result["verdict"] is True + contents = client.models.generate_content.call_args.kwargs["contents"] + parts = contents[0]["parts"] + # First part is the text JSON, then a label, then the attachment file part. + assert parts[0]["text"] + file_parts = [p for p in parts if "fileData" in p] + assert len(file_parts) == 1 + assert file_parts[0]["fileData"]["fileUri"] == "https://x.com/a/photo.jpg" + + def test_document_relevance_in_schema_and_result(self) -> None: + """Selected doc column gets its own relevance boolean in column_relevance.""" + client = MagicMock() + response = MagicMock() + response.text = json.dumps( + { + "decision": "ACCEPT", + "Problem": True, + "Documents": True, + "reasoning": "ok", + } + ) + client.models.generate_content.return_value = response + att = AssessmentAttachment(column="Documents", type="image", format="url") + row = {"Problem": "p", "Documents": "https://x.com/a/photo.jpg"} + + result = run_topic_relevance( + row_idx=3, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + attachments=[att], + type_cache={}, + ) + + # Document column carried into the per-column relevance map -> exports + # as topic_relevance_Documents. + assert "Documents" in result["column_relevance"] + assert result["column_relevance"]["Documents"] is True + schema = client.models.generate_content.call_args.kwargs[ + "config" + ].response_schema + assert "Documents" in schema["properties"] + + def test_no_attachments_text_only(self) -> None: + client = _client_returning("REJECT") + row = {"Problem": "p"} + + result = run_topic_relevance( + row_idx=1, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + ) + + assert result["verdict"] is False + contents = client.models.generate_content.call_args.kwargs["contents"] + parts = contents[0]["parts"] + assert len(parts) == 1 + assert parts[0]["text"] + + def test_mixed_column_pdf_item_detected(self) -> None: + client = _client_returning("ACCEPT") + att = AssessmentAttachment(column="Documents", type="mixed", format="url") + row = {"Problem": "p", "Documents": "https://x.com/a/report.pdf"} + + run_topic_relevance( + row_idx=2, + row=row, + columns=["Problem"], + user_prompt="rubric", + gemini_client=client, + model="gemini-2.5-flash", + attachments=[att], + type_cache={}, + ) + + parts = client.models.generate_content.call_args.kwargs["contents"][0]["parts"] + pdf_parts = [ + p + for p in parts + if p.get("fileData", {}).get("mimeType") == "application/pdf" + ] + assert len(pdf_parts) == 1 From 98acf866904d67fa911bb1b659e04fe3d1e5f678 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 11:28:50 +0530 Subject: [PATCH 5/8] feat(tests): update assessment run status to 'l2_processing' and refactor batch submission to Celery task --- backend/app/tests/assessment/test_cron.py | 6 +-- backend/app/tests/assessment/test_crud.py | 3 ++ backend/app/tests/assessment/test_service.py | 53 +++++++------------- 3 files changed, 23 insertions(+), 39 deletions(-) diff --git a/backend/app/tests/assessment/test_cron.py b/backend/app/tests/assessment/test_cron.py index c9407bd5c..d9e8527eb 100644 --- a/backend/app/tests/assessment/test_cron.py +++ b/backend/app/tests/assessment/test_cron.py @@ -115,7 +115,7 @@ async def test_active_run_processed(self) -> None: session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.status = "l2_processing" session.exec.return_value.all.return_value = [assessment] with patch( @@ -140,7 +140,7 @@ async def test_active_run_failure_and_cleanup_failure(self) -> None: session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.status = "l2_processing" session.exec.return_value.all.return_value = [assessment] with patch( @@ -164,7 +164,7 @@ async def test_active_run_failure_updates_db_with_same_error_message(self) -> No session = MagicMock() assessment = _make_assessment(id=1, status="processing") run = _make_run(id=11) - run.status = "processing" + run.status = "l2_processing" session.exec.return_value.all.return_value = [assessment] with patch( diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index 2bc076342..1d456329e 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -232,6 +232,9 @@ def test_build_run_stats(self) -> None: total_items=2, error_message=None, updated_at=datetime(2024, 1, 1), + l1_total_rows=None, + l1_total_passed=None, + l1_total_rejected=None, ), ] stats = build_run_stats(runs) diff --git a/backend/app/tests/assessment/test_service.py b/backend/app/tests/assessment/test_service.py index b3654fa9b..e3d46ef55 100644 --- a/backend/app/tests/assessment/test_service.py +++ b/backend/app/tests/assessment/test_service.py @@ -142,9 +142,6 @@ def test_google_provider_is_supported(self) -> None: config_blob = SimpleNamespace( completion=SimpleNamespace(provider="google", params={"model": "gemini"}) ) - batch_job = MagicMock() - batch_job.id = 101 - batch_job.total_items = 3 with ( patch( @@ -164,13 +161,8 @@ def test_google_provider_is_supported(self) -> None: return_value=run, ), patch( - "app.services.assessment.service.submit_assessment_batch", - return_value=batch_job, - ) as submit_batch, - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ), + "app.celery.tasks.job_execution.run_assessment_run" + ) as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -181,8 +173,10 @@ def test_google_provider_is_supported(self) -> None: project_id=1, ) + # Google is an accepted provider — no rejection, one Celery task dispatched. assert response.num_configs == 1 - assert submit_batch.call_args.kwargs["config_blob"] is config_blob + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 def test_defaults_missing_provider_to_openai(self) -> None: session = MagicMock() @@ -194,9 +188,6 @@ def test_defaults_missing_provider_to_openai(self) -> None: config_blob = SimpleNamespace( completion=SimpleNamespace(provider=None, params={"model": "gpt-4.1-mini"}) ) - batch_job = MagicMock() - batch_job.id = 101 - batch_job.total_items = 3 with ( patch( @@ -216,13 +207,8 @@ def test_defaults_missing_provider_to_openai(self) -> None: return_value=run, ) as create_run, patch( - "app.services.assessment.service.submit_assessment_batch", - return_value=batch_job, - ) as submit_batch, - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ), + "app.celery.tasks.job_execution.run_assessment_run" + ) as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -238,11 +224,7 @@ def test_defaults_missing_provider_to_openai(self) -> None: assert response.runs[0].run_id == 11 assessment_input = create_run.call_args.kwargs["assessment_input"] assert assessment_input["system_instruction"] == "Assess strictly" - assert ( - submit_batch.call_args.kwargs["assessment_input"]["system_instruction"] - == "Assess strictly" - ) - submit_batch.assert_called_once() + dispatch.delay.assert_called_once() def test_rejects_default_tagged_config(self) -> None: """Configs explicitly tagged 'default' must be rejected for assessment.""" @@ -278,14 +260,15 @@ def test_rejects_default_tagged_config(self) -> None: # Tag check must fire BEFORE config resolution. resolve.assert_not_called() - def test_batch_submission_failure_marks_run_failed(self) -> None: + def test_dispatches_one_celery_task_per_config(self) -> None: + """Batch submission moved to the Celery task; start_assessment only + creates runs and dispatches one task per resolved config.""" session = MagicMock() request = _make_request(UUID("00000000-0000-0000-0000-000000000001")) dataset = _make_dataset() assessment = MagicMock() assessment.id = 21 run = _make_run() - run.status = "failed" config_blob = SimpleNamespace( completion=SimpleNamespace( provider="openai", params={"model": "gpt-4.1-mini"} @@ -310,13 +293,8 @@ def test_batch_submission_failure_marks_run_failed(self) -> None: return_value=run, ), patch( - "app.services.assessment.service.submit_assessment_batch", - side_effect=RuntimeError("submit failed"), - ), - patch( - "app.services.assessment.service.update_assessment_run_status", - return_value=run, - ) as update_run, + "app.celery.tasks.job_execution.run_assessment_run" + ) as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -327,7 +305,10 @@ def test_batch_submission_failure_marks_run_failed(self) -> None: project_id=1, ) assert response.num_configs == 1 - assert update_run.called + dispatch.delay.assert_called_once() + assert dispatch.delay.call_args.kwargs["run_id"] == 11 + assert dispatch.delay.call_args.kwargs["organization_id"] == 1 + assert dispatch.delay.call_args.kwargs["project_id"] == 1 class TestRetryHelpers: From 0addb717f761bd2aab5895d138db4e259f03a715 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 11:37:11 +0530 Subject: [PATCH 6/8] refactor(tests): streamline patching of run_assessment_run in TestStartAssessment --- backend/app/tests/assessment/test_service.py | 12 +++--------- 1 file changed, 3 insertions(+), 9 deletions(-) diff --git a/backend/app/tests/assessment/test_service.py b/backend/app/tests/assessment/test_service.py index e3d46ef55..b56ec72b7 100644 --- a/backend/app/tests/assessment/test_service.py +++ b/backend/app/tests/assessment/test_service.py @@ -160,9 +160,7 @@ def test_google_provider_is_supported(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch( - "app.celery.tasks.job_execution.run_assessment_run" - ) as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -206,9 +204,7 @@ def test_defaults_missing_provider_to_openai(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ) as create_run, - patch( - "app.celery.tasks.job_execution.run_assessment_run" - ) as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): @@ -292,9 +288,7 @@ def test_dispatches_one_celery_task_per_config(self) -> None: "app.services.assessment.service.create_assessment_run", return_value=run, ), - patch( - "app.celery.tasks.job_execution.run_assessment_run" - ) as dispatch, + patch("app.celery.tasks.job_execution.run_assessment_run") as dispatch, patch("app.services.assessment.service.recompute_assessment_status"), _assessment_config_crud_patch(), ): From ad8e29f7e304d6fcf7f07afdf450115ea0e7c012 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 14:16:35 +0530 Subject: [PATCH 7/8] feat(tests): add comprehensive tests for L1 duplicate detection and pipeline orchestrator --- backend/app/tests/assessment/test_batch.py | 46 ++++ backend/app/tests/assessment/test_crud.py | 78 ++++++- .../assessment/test_duplicate_detection.py | 132 +++++++++++ backend/app/tests/assessment/test_pipeline.py | 151 +++++++++++++ .../tests/assessment/test_post_processing.py | 212 ++++++++++++++++++ 5 files changed, 618 insertions(+), 1 deletion(-) create mode 100644 backend/app/tests/assessment/test_duplicate_detection.py create mode 100644 backend/app/tests/assessment/test_pipeline.py create mode 100644 backend/app/tests/assessment/test_post_processing.py diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index 41d84198d..aa0fce1a0 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -527,3 +527,49 @@ def test_mixed_column_resolves_both_types(self) -> None: resolved = resolve_attachment_values(value, att, {}) types = [obj["type"] for obj in resolved] assert types == ["input_image", "input_file"] + + +class TestAttachmentMagicAndMime: + def test_image_magic_all_formats(self) -> None: + from app.services.assessment.utils.attachments import _image_mime_from_magic + + assert _image_mime_from_magic(b"\x89PNG\r\n\x1a\n") == "image/png" + assert _image_mime_from_magic(b"\xff\xd8\xff") == "image/jpeg" + assert _image_mime_from_magic(b"GIF89a") == "image/gif" + assert _image_mime_from_magic(b"GIF87a") == "image/gif" + assert _image_mime_from_magic(b"BM....") == "image/bmp" + assert _image_mime_from_magic(b"RIFF\x00\x00\x00\x00WEBP") == "image/webp" + assert _image_mime_from_magic(b"II*\x00") == "image/tiff" + assert _image_mime_from_magic(b"MM\x00*") == "image/tiff" + assert _image_mime_from_magic(b"nope") is None + + def test_type_from_magic_pdf_and_none(self) -> None: + from app.services.assessment.utils.attachments import _type_from_magic + + assert _type_from_magic(b"%PDF-1.7") == "pdf" + assert _type_from_magic(b"\x89PNG\r\n\x1a\n") == "image" + assert _type_from_magic(b"random") is None + + def test_guess_image_mime_from_url_variants(self) -> None: + from app.services.assessment.utils.attachments import _guess_image_mime_from_url + + assert _guess_image_mime_from_url("http://x/a.PNG") == "image/png" + assert _guess_image_mime_from_url("http://x/a.jpeg") == "image/jpeg" + assert _guess_image_mime_from_url("http://x/a.webp") == "image/webp" + assert _guess_image_mime_from_url("http://x/a.txt") is None + + def test_resolve_image_mime_data_url(self) -> None: + from app.services.assessment.utils.attachments import ( + resolve_image_mime_and_payload, + ) + + mime, payload = resolve_image_mime_and_payload( + "data:image/webp;base64,AAAA", "base64" + ) + assert mime == "image/webp" + assert payload == "AAAA" + + def test_decode_base64_prefix_empty(self) -> None: + from app.services.assessment.utils.attachments import _decode_base64_prefix + + assert _decode_base64_prefix(" ") is None diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index 1d456329e..1cf30249e 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -2,7 +2,7 @@ from datetime import datetime from types import SimpleNamespace -from unittest.mock import MagicMock +from unittest.mock import MagicMock, patch from uuid import UUID import pytest @@ -25,7 +25,9 @@ list_assessments, recompute_assessment_status, update_assessment_run_status, + update_run_post_processing_config, ) +from app.crud.assessment.core import update_assessment_run_l1_stats from app.models.stt_evaluation import EvaluationType @@ -300,3 +302,77 @@ def test_recompute_commit_failure_rolls_back(self) -> None: with pytest.raises(RuntimeError): recompute_assessment_status(session=session, assessment_id=1) session.rollback.assert_called_once() + + +class TestUpdateRunPostProcessingConfig: + def test_sets_config_in_input_blob(self) -> None: + session = MagicMock() + run = SimpleNamespace(id=5, input={"text_columns": ["q"]}) + cfg = {"computed_columns": [{"name": "T", "formula": "@a"}]} + with patch("app.crud.assessment.core.flag_modified") as flag: + out = update_run_post_processing_config( + session=session, run=run, config=cfg + ) + assert out.input["post_processing_config"] == cfg + assert out.input["text_columns"] == ["q"] + flag.assert_called_once_with(run, "input") + session.commit.assert_called_once() + + def test_none_input_handled(self) -> None: + session = MagicMock() + run = SimpleNamespace(id=6, input=None) + with patch("app.crud.assessment.core.flag_modified"): + out = update_run_post_processing_config( + session=session, run=run, config=None + ) + assert out.input == {"post_processing_config": None} + + def test_commit_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + run = SimpleNamespace(id=7, input={}) + with patch("app.crud.assessment.core.flag_modified"): + with pytest.raises(RuntimeError): + update_run_post_processing_config(session=session, run=run, config={}) + session.rollback.assert_called_once() + + +class TestUpdateAssessmentRunL1Stats: + def test_sets_stats_fields(self) -> None: + session = MagicMock() + run = SimpleNamespace( + id=8, + updated_at=None, + l1_object_store_url=None, + l1_total_rows=None, + l1_total_passed=None, + l1_total_rejected=None, + ) + out = update_assessment_run_l1_stats( + session=session, + run=run, + l1_object_store_url="s3://x", + l1_total_rows=10, + l1_total_passed=7, + l1_total_rejected=3, + ) + assert out.l1_object_store_url == "s3://x" + assert out.l1_total_rows == 10 + assert out.l1_total_passed == 7 + assert out.l1_total_rejected == 3 + session.commit.assert_called_once() + + def test_commit_failure_rolls_back(self) -> None: + session = MagicMock() + session.commit.side_effect = RuntimeError("db error") + run = SimpleNamespace( + id=9, + updated_at=None, + l1_object_store_url=None, + l1_total_rows=None, + l1_total_passed=None, + l1_total_rejected=None, + ) + with pytest.raises(RuntimeError): + update_assessment_run_l1_stats(session=session, run=run, l1_total_rows=1) + session.rollback.assert_called_once() diff --git a/backend/app/tests/assessment/test_duplicate_detection.py b/backend/app/tests/assessment/test_duplicate_detection.py new file mode 100644 index 000000000..24d5ac951 --- /dev/null +++ b/backend/app/tests/assessment/test_duplicate_detection.py @@ -0,0 +1,132 @@ +"""Tests for L1 duplicate detection.""" + +import json +from unittest.mock import MagicMock + +from app.services.assessment.l1.duplicate_detection import ( + _build_combined, + _parse_verdict, + run_duplicate_detection, +) + + +def _vague_client(vague: bool, reason: str = "r") -> MagicMock: + client = MagicMock() + resp = MagicMock() + resp.text = json.dumps({"vague": vague, "reason": reason}) + client.models.generate_content.return_value = resp + return client + + +class TestBuildCombined: + def test_joins_non_empty(self) -> None: + out = _build_combined({"Problem": "p", "Solution": "s", "Empty": " "}) + assert "Problem:\np" in out + assert "Solution:\ns" in out + assert "Empty" not in out + + +class TestParseVerdict: + def test_full_fields(self) -> None: + raw = ( + "Verdict: DUPLICATE\n" + "Title: Some Idea\n" + "Source: https://x.com/a\n" + "URL: https://x.com/a\n" + "Matching sentence: a beam alarm\n" + "Reason: same mechanism" + ) + out = _parse_verdict(raw) + assert out["verdict"] == "DUPLICATE" + assert out["match_title"] == "Some Idea" + assert out["source_url"] == "https://x.com/a" + assert out["matching_sentence"] == "a beam alarm" + assert out["reason"] == "same mechanism" + + def test_unique_verdict_only(self) -> None: + out = _parse_verdict("Verdict: UNIQUE\nReason: nothing matches") + assert out["verdict"] == "UNIQUE" + assert out["match_title"] is None + + def test_regex_fallback_when_key_missing(self) -> None: + out = _parse_verdict("The result is clearly OVERLAP here.") + assert out["verdict"] == "OVERLAP" + + def test_no_verdict_stays_empty(self) -> None: + out = _parse_verdict("no decision present") + assert out["verdict"] == "" + + +class TestRunDuplicateDetection: + def test_vague_short_circuits(self) -> None: + client = _vague_client(True, "too vague") + result = run_duplicate_detection( + row_idx=0, + row={"Problem": "x"}, + columns=["Problem"], + gemini_client=client, + model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "VAGUE" + assert result["reason"] == "too vague" + # Only the vague check is called; no file-search second call. + assert client.models.generate_content.call_count == 1 + + def test_not_vague_runs_file_search(self) -> None: + client = MagicMock() + vague_resp = MagicMock() + vague_resp.text = json.dumps({"vague": False, "reason": ""}) + search_resp = MagicMock() + search_resp.text = "Verdict: UNIQUE\nReason: novel" + client.models.generate_content.side_effect = [vague_resp, search_resp] + + result = run_duplicate_detection( + row_idx=1, + row={"Problem": "p", "Solution": "s"}, + columns=["Problem", "Solution"], + gemini_client=client, + model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "UNIQUE" + assert result["reason"] == "novel" + assert result["row_id"] == "row_1" + + def test_file_search_error_returns_error_verdict(self) -> None: + client = MagicMock() + vague_resp = MagicMock() + vague_resp.text = json.dumps({"vague": False, "reason": ""}) + client.models.generate_content.side_effect = [ + vague_resp, + RuntimeError("search boom"), + ] + + result = run_duplicate_detection( + row_idx=2, + row={"Problem": "p"}, + columns=["Problem"], + gemini_client=client, + model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "ERROR" + assert "search boom" in result["reason"] + + def test_vague_check_parse_error_defaults_not_vague(self) -> None: + client = MagicMock() + bad_vague = MagicMock() + bad_vague.text = "not json" + search_resp = MagicMock() + search_resp.text = "Verdict: PARTIAL_MATCH\nTitle: T\nReason: theme" + client.models.generate_content.side_effect = [bad_vague, search_resp] + + result = run_duplicate_detection( + row_idx=3, + row={"Problem": "p"}, + columns=["Problem"], + gemini_client=client, + model="gemini-2.5-flash", + store_name="store", + ) + assert result["verdict"] == "PARTIAL_MATCH" diff --git a/backend/app/tests/assessment/test_pipeline.py b/backend/app/tests/assessment/test_pipeline.py new file mode 100644 index 000000000..faa64693e --- /dev/null +++ b/backend/app/tests/assessment/test_pipeline.py @@ -0,0 +1,151 @@ +"""Tests for the L1 pipeline orchestrator.""" + +from contextlib import ExitStack +from unittest.mock import MagicMock, patch + +from app.services.assessment.l1.pipeline import run_l1_pipeline + + +def _run() -> MagicMock: + run = MagicMock() + run.id = 99 + return run + + +def _tr(verdict: bool, decision: str = "ACCEPT") -> dict: + return { + "row_id": "row", + "verdict": verdict, + "decision": decision, + "column_relevance": {"Problem": verdict}, + "reasoning": "r", + } + + +def _patches(stack: ExitStack, *, tr_side=None, dup_return=None): + """Patch the pipeline's external deps; return the TR mock.""" + client = MagicMock() + stack.enter_context( + patch( + "app.services.assessment.l1.pipeline.GeminiClient.from_credentials", + return_value=MagicMock(client=client), + ) + ) + stack.enter_context( + patch( + "app.services.assessment.l1.pipeline.get_cloud_storage", + return_value=MagicMock(), + ) + ) + stack.enter_context( + patch( + "app.services.assessment.l1.pipeline.upload_jsonl_to_object_store", + return_value="s3://l1.json", + ) + ) + stack.enter_context( + patch("app.crud.assessment.core.update_assessment_run_l1_stats") + ) + tr_mock = stack.enter_context( + patch("app.services.assessment.l1.pipeline.run_topic_relevance") + ) + if tr_side is not None: + tr_mock.side_effect = tr_side + dup_mock = stack.enter_context( + patch("app.services.assessment.l1.pipeline.run_duplicate_detection") + ) + if dup_return is not None: + dup_mock.return_value = dup_return + return tr_mock, dup_mock + + +class TestRunL1Pipeline: + def test_no_filters_configured_passthrough(self) -> None: + rows = [{"Problem": "a"}, {"Problem": "b"}] + passed, indices, results = run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={}, + session=MagicMock(), + organization_id=1, + project_id=1, + ) + assert passed == rows + assert indices == [0, 1] + assert results == [] + + def test_topic_relevance_filters_rejected_rows(self) -> None: + rows = [{"Problem": "keep"}, {"Problem": "drop"}, {"Problem": "keep2"}] + # idx 1 rejected. + side = [_tr(True), _tr(False, "REJECT"), _tr(True)] + with ExitStack() as stack: + _patches(stack, tr_side=side) + passed, indices, results = run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={ + "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"} + }, + session=MagicMock(), + organization_id=1, + project_id=1, + ) + assert indices == [0, 2] + assert [r["Problem"] for r in passed] == ["keep", "keep2"] + assert len(results) == 3 + assert results[1]["l1_passed"] is False + + def test_duplicate_detection_runs_on_passed_rows(self) -> None: + rows = [{"Problem": "a", "Solution": "b"}] + dup = { + "row_id": "row_0", + "verdict": "UNIQUE", + "match_title": None, + "source_url": None, + "matching_sentence": None, + "reason": "novel", + } + with ExitStack() as stack: + tr_mock, dup_mock = _patches(stack, tr_side=[_tr(True)], dup_return=dup) + _, _, results = run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={ + "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"}, + "duplicate_detection": {"columns": ["Problem", "Solution"]}, + }, + session=MagicMock(), + organization_id=1, + project_id=1, + ) + dup_mock.assert_called_once() + assert results[0]["duplicate_detection"]["verdict"] == "UNIQUE" + + def test_attachment_columns_filtered_to_selection(self) -> None: + from app.models.assessment import AssessmentAttachment + + rows = [{"Problem": "a", "Docs": "x", "Other": "y"}] + atts = [ + AssessmentAttachment(column="Docs", type="image", format="url"), + AssessmentAttachment(column="Other", type="image", format="url"), + ] + with ExitStack() as stack: + tr_mock, _ = _patches(stack, tr_side=[_tr(True)]) + run_l1_pipeline( + run=_run(), + rows=rows, + l1_config={ + "topic_relevance": { + "columns": ["Problem"], + "prompt": "rubric", + "attachment_columns": ["Docs"], + } + }, + session=MagicMock(), + organization_id=1, + project_id=1, + attachments=atts, + ) + # run_topic_relevance is called with only the selected attachment ("Docs"). + passed_atts = tr_mock.call_args.args[6] + assert [a.column for a in passed_atts] == ["Docs"] diff --git a/backend/app/tests/assessment/test_post_processing.py b/backend/app/tests/assessment/test_post_processing.py new file mode 100644 index 000000000..0ee7b81cc --- /dev/null +++ b/backend/app/tests/assessment/test_post_processing.py @@ -0,0 +1,212 @@ +"""Tests for the assessment export post-processing engine.""" + +from app.services.assessment.utils.post_processing import ( + apply_computed_columns, + apply_filter, + apply_post_processing, + apply_sort, + evaluate_formula, +) + + +class TestEvaluateFormula: + def test_addition(self) -> None: + assert evaluate_formula("@a + @b", {"a": 2, "b": 3}) == 5.0 + + def test_all_operators(self) -> None: + row = {"a": 10, "b": 4} + assert evaluate_formula("@a - @b", row) == 6.0 + assert evaluate_formula("@a * @b", row) == 40.0 + assert evaluate_formula("@a / @b", row) == 2.5 + assert evaluate_formula("-@a", row) == -10.0 + + def test_precedence_and_constants(self) -> None: + assert evaluate_formula("@a + @b * 0.5", {"a": 1, "b": 4}) == 3.0 + + def test_string_numeric_values_coerced(self) -> None: + assert evaluate_formula("@a + @b", {"a": "2", "b": "3"}) == 5.0 + + def test_missing_column_is_zero(self) -> None: + assert evaluate_formula("@a + @b", {"a": 5}) == 5.0 + + def test_non_numeric_value_is_zero(self) -> None: + assert evaluate_formula("@a + @b", {"a": 5, "b": "abc"}) == 5.0 + + def test_unsupported_operation_returns_none(self) -> None: + # Power operator is not in the safe-ops allowlist. + assert evaluate_formula("@a ** @b", {"a": 2, "b": 3}) is None + + def test_syntax_error_returns_none(self) -> None: + assert evaluate_formula("@a +", {"a": 1}) is None + + +class TestApplyComputedColumns: + def test_adds_column_in_place(self) -> None: + rows = [{"a": 1, "b": 2}, {"a": 3, "b": 4}] + apply_computed_columns(rows, [{"name": "total", "formula": "@a + @b"}]) + assert rows[0]["total"] == 3.0 + assert rows[1]["total"] == 7.0 + + def test_skips_empty_name_or_formula(self) -> None: + rows = [{"a": 1}] + apply_computed_columns( + rows, + [ + {"name": "", "formula": "@a"}, + {"name": "x", "formula": ""}, + ], + ) + assert rows[0] == {"a": 1} + + +class TestApplyFilter: + def test_no_rules_returns_all(self) -> None: + rows = [{"a": 1}, {"a": 2}] + assert apply_filter(rows, []) == rows + + def test_eq_ne(self) -> None: + rows = [{"x": "Yes"}, {"x": "no"}] + assert apply_filter(rows, [{"column": "x", "op": "eq", "value": "yes"}]) == [ + {"x": "Yes"} + ] + assert apply_filter(rows, [{"column": "x", "op": "ne", "value": "yes"}]) == [ + {"x": "no"} + ] + + def test_contains_not_contains(self) -> None: + rows = [{"x": "hello world"}, {"x": "bye"}] + assert apply_filter( + rows, [{"column": "x", "op": "contains", "value": "world"}] + ) == [{"x": "hello world"}] + assert apply_filter( + rows, [{"column": "x", "op": "not_contains", "value": "world"}] + ) == [{"x": "bye"}] + + def test_in_not_in(self) -> None: + rows = [{"x": "a"}, {"x": "b"}] + assert apply_filter( + rows, [{"column": "x", "op": "in", "value": ["a", "c"]}] + ) == [{"x": "a"}] + assert apply_filter( + rows, [{"column": "x", "op": "not_in", "value": ["a", "c"]}] + ) == [{"x": "b"}] + + def test_is_empty_is_not_empty(self) -> None: + rows = [{"x": ""}, {"x": "v"}, {"x": None}] + assert apply_filter(rows, [{"column": "x", "op": "is_empty"}]) == [ + {"x": ""}, + {"x": None}, + ] + assert apply_filter(rows, [{"column": "x", "op": "is_not_empty"}]) == [ + {"x": "v"} + ] + + def test_numeric_comparisons(self) -> None: + rows = [{"n": 1}, {"n": 5}, {"n": 10}] + assert apply_filter(rows, [{"column": "n", "op": "gt", "value": 4}]) == [ + {"n": 5}, + {"n": 10}, + ] + assert apply_filter(rows, [{"column": "n", "op": "lt", "value": 5}]) == [ + {"n": 1} + ] + assert apply_filter(rows, [{"column": "n", "op": "gte", "value": 5}]) == [ + {"n": 5}, + {"n": 10}, + ] + assert apply_filter(rows, [{"column": "n", "op": "lte", "value": 5}]) == [ + {"n": 1}, + {"n": 5}, + ] + + def test_numeric_filter_non_numeric_excluded(self) -> None: + rows = [{"n": "abc"}, {"n": 5}] + assert apply_filter(rows, [{"column": "n", "op": "gt", "value": 1}]) == [ + {"n": 5} + ] + + def test_unknown_op_keeps_row(self) -> None: + rows = [{"x": "a"}] + assert apply_filter(rows, [{"column": "x", "op": "weird", "value": 1}]) == rows + + def test_and_logic_across_rules(self) -> None: + rows = [{"n": 5, "x": "yes"}, {"n": 5, "x": "no"}, {"n": 1, "x": "yes"}] + out = apply_filter( + rows, + [ + {"column": "n", "op": "gte", "value": 5}, + {"column": "x", "op": "eq", "value": "yes"}, + ], + ) + assert out == [{"n": 5, "x": "yes"}] + + +class TestApplySort: + def test_no_rules_returns_input(self) -> None: + rows = [{"n": 2}, {"n": 1}] + assert apply_sort(rows, []) == rows + + def test_numeric_asc_desc(self) -> None: + rows = [{"n": 3}, {"n": 1}, {"n": 2}] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "asc"}]) + ] == [1, 2, 3] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "desc"}]) + ] == [3, 2, 1] + + def test_none_values_sort_last(self) -> None: + rows = [{"n": None}, {"n": 2}, {"n": 1}] + assert [ + r["n"] for r in apply_sort(rows, [{"column": "n", "direction": "asc"}]) + ] == [1, 2, None] + + def test_string_asc_desc(self) -> None: + rows = [{"s": "banana"}, {"s": "apple"}, {"s": "cherry"}] + assert [ + r["s"] for r in apply_sort(rows, [{"column": "s", "direction": "asc"}]) + ] == ["apple", "banana", "cherry"] + assert [ + r["s"] for r in apply_sort(rows, [{"column": "s", "direction": "desc"}]) + ] == ["cherry", "banana", "apple"] + + def test_multi_rule_priority(self) -> None: + rows = [ + {"grp": "a", "n": 2}, + {"grp": "b", "n": 1}, + {"grp": "a", "n": 1}, + ] + out = apply_sort( + rows, + [ + {"column": "grp", "direction": "asc"}, + {"column": "n", "direction": "desc"}, + ], + ) + assert out == [ + {"grp": "a", "n": 2}, + {"grp": "a", "n": 1}, + {"grp": "b", "n": 1}, + ] + + +class TestApplyPostProcessing: + def test_none_config_is_noop(self) -> None: + rows = [{"a": 1}] + assert apply_post_processing(rows, None) is rows + + def test_full_pipeline(self) -> None: + rows = [ + {"Novelty": 3, "Feasibility": 4}, + {"Novelty": 9, "Feasibility": 8}, + {"Novelty": 1, "Feasibility": 1}, + ] + config = { + "computed_columns": [ + {"name": "Total", "formula": "@Novelty + @Feasibility"} + ], + "filter": [{"column": "Total", "op": "gt", "value": 5}], + "sort": [{"column": "Total", "direction": "desc"}], + } + out = apply_post_processing(rows, config) + assert [r["Total"] for r in out] == [17.0, 7.0] From e020717949fe8b1a5b8bdf1d09207c1b6ceb7e23 Mon Sep 17 00:00:00 2001 From: Prashant Vasudevan <71649489+vprashrex@users.noreply.github.com> Date: Tue, 2 Jun 2026 17:28:06 +0530 Subject: [PATCH 8/8] feat: implement prefilter pipeline with topic relevance and duplicate detection - Added a new prefilter pipeline orchestrator that runs topic relevance and duplicate detection filters in series. - Created `run_topic_relevance` and `run_duplicate_detection` functions to handle respective filtering logic. - Updated assessment service to utilize prefilter configuration instead of L1 configuration. - Modified assessment tasks to reflect the new prefilter processing status and error handling. - Adjusted utility functions and export logic to accommodate prefilter results. - Enhanced tests to cover the new prefilter functionality and ensure proper integration. --- ...dd_prefilter_columns_to_assessment_run.py} | 24 ++--- backend/app/api/routes/assessment/runs.py | 6 +- backend/app/celery/tasks/job_execution.py | 5 +- backend/app/core/config.py | 13 ++- backend/app/crud/assessment/__init__.py | 4 +- backend/app/crud/assessment/batch.py | 2 +- backend/app/crud/assessment/core.py | 42 ++++----- backend/app/models/assessment.py | 36 ++++---- .../app/services/assessment/l1/__init__.py | 3 - .../services/assessment/prefilter/__init__.py | 3 + .../{l1 => prefilter}/duplicate_detection.py | 10 ++- .../assessment/{l1 => prefilter}/pipeline.py | 87 +++++++++--------- .../{l1 => prefilter}/topic_relevance.py | 6 +- backend/app/services/assessment/service.py | 8 +- backend/app/services/assessment/tasks.py | 36 ++++---- .../services/assessment/utils/attachments.py | 55 +++++++----- .../app/services/assessment/utils/export.py | 90 ++++++++++--------- backend/app/tests/assessment/test_batch.py | 46 ++++++++++ backend/app/tests/assessment/test_crud.py | 46 +++++----- .../assessment/test_duplicate_detection.py | 4 +- backend/app/tests/assessment/test_pipeline.py | 36 ++++---- .../tests/assessment/test_topic_relevance.py | 4 +- 22 files changed, 328 insertions(+), 238 deletions(-) rename backend/app/alembic/versions/{064_add_l1_columns_to_assessment_run.py => 064_add_prefilter_columns_to_assessment_run.py} (60%) delete mode 100644 backend/app/services/assessment/l1/__init__.py create mode 100644 backend/app/services/assessment/prefilter/__init__.py rename backend/app/services/assessment/{l1 => prefilter}/duplicate_detection.py (95%) rename backend/app/services/assessment/{l1 => prefilter}/pipeline.py (67%) rename backend/app/services/assessment/{l1 => prefilter}/topic_relevance.py (94%) diff --git a/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py similarity index 60% rename from backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py rename to backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py index bce33e6cd..1720e21b4 100644 --- a/backend/app/alembic/versions/064_add_l1_columns_to_assessment_run.py +++ b/backend/app/alembic/versions/064_add_prefilter_columns_to_assessment_run.py @@ -1,4 +1,4 @@ -"""Add L1 pipeline columns to assessment_run +"""Add prefilter pipeline columns to assessment_run Revision ID: 064 Revises: 063 @@ -19,25 +19,25 @@ def upgrade() -> None: op.add_column( "assessment_run", sa.Column( - "l1_object_store_url", + "prefilter_object_store_url", sa.String(), nullable=True, - comment="S3 URL of stored L1 filter results JSON", + comment="S3 URL of stored prefilter filter results JSON", ), ) op.add_column( "assessment_run", sa.Column( - "l1_total_rows", + "prefilter_total_rows", sa.Integer(), nullable=True, - comment="Total rows fed into L1 pipeline", + comment="Total rows fed into prefilter pipeline", ), ) op.add_column( "assessment_run", sa.Column( - "l1_total_passed", + "prefilter_total_passed", sa.Integer(), nullable=True, comment="Rows that passed topic relevance and went to L2", @@ -46,16 +46,16 @@ def upgrade() -> None: op.add_column( "assessment_run", sa.Column( - "l1_total_rejected", + "prefilter_total_rejected", sa.Integer(), nullable=True, - comment="Rows rejected by topic relevance, stopped at L1", + comment="Rows rejected by topic relevance, stopped at prefilter", ), ) def downgrade() -> None: - op.drop_column("assessment_run", "l1_total_rejected") - op.drop_column("assessment_run", "l1_total_passed") - op.drop_column("assessment_run", "l1_total_rows") - op.drop_column("assessment_run", "l1_object_store_url") + op.drop_column("assessment_run", "prefilter_total_rejected") + op.drop_column("assessment_run", "prefilter_total_passed") + op.drop_column("assessment_run", "prefilter_total_rows") + op.drop_column("assessment_run", "prefilter_object_store_url") diff --git a/backend/app/api/routes/assessment/runs.py b/backend/app/api/routes/assessment/runs.py index 3c3abd57a..2825e5c86 100644 --- a/backend/app/api/routes/assessment/runs.py +++ b/backend/app/api/routes/assessment/runs.py @@ -67,9 +67,9 @@ def _build_run_public( total_items=run.total_items, error_message=run.error_message, input=run.input, - l1_total_rows=run.l1_total_rows, - l1_total_passed=run.l1_total_passed, - l1_total_rejected=run.l1_total_rejected, + prefilter_total_rows=run.prefilter_total_rows, + prefilter_total_passed=run.prefilter_total_passed, + prefilter_total_rejected=run.prefilter_total_rejected, post_processing_config=(run.input or {}).get("post_processing_config"), inserted_at=run.inserted_at, updated_at=run.updated_at, diff --git a/backend/app/celery/tasks/job_execution.py b/backend/app/celery/tasks/job_execution.py index ec7ad1bd0..6a249a92e 100644 --- a/backend/app/celery/tasks/job_execution.py +++ b/backend/app/celery/tasks/job_execution.py @@ -232,9 +232,8 @@ def run_tts_batch_submission( ) -@celery_app.task( - bind=True, queue="low_priority", priority=1, soft_time_limit=1800, time_limit=2100 -) +@celery_app.task(bind=True, queue="low_priority", priority=1) +@gevent_timeout(settings.ASSESSMENT_RUN_SOFT_TIME_LIMIT, "run_assessment_run") def run_assessment_run( self, run_id: int, diff --git a/backend/app/core/config.py b/backend/app/core/config.py index 60504147b..e2e3a5ff1 100644 --- a/backend/app/core/config.py +++ b/backend/app/core/config.py @@ -172,11 +172,18 @@ def AWS_S3_BUCKET(self) -> str: PENDING_JOB_QUERY_TIMEOUT_MS: int = 1000 # Assessment - ASSESSMENT_L1_GEMINI_MODEL: str = "gemini-3.1-flash-lite" - ASSESSMENT_L1_CONCURRENT_WORKERS: int = 8 - ASSESSMENT_L1_DUPLICATE_STORE_NAME: str = ( + ASSESSMENT_PREFILTER_GEMINI_MODEL: str = "gemini-3.1-flash-lite" + ASSESSMENT_PREFILTER_CONCURRENT_WORKERS: int = 8 + ASSESSMENT_PREFILTER_DUPLICATE_STORE_NAME: str = ( "fileSearchStores/inquilabcorpus-782mxjcwisaz" ) + # Soft timeout for the full assessment run task (prefilter pipeline + batch + # submission). Larger than the default task limit because prefilter runs many + # concurrent LLM calls over the whole dataset. Seconds. Default 2 hours. + ASSESSMENT_RUN_SOFT_TIME_LIMIT: int = 7200 + # Timeout for prefilter Gemini calls to prevent pipeline stalls from slow/hung requests + # (default: 2 minutes, in ms) + ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS: int = 120000 @computed_field # type: ignore[prop-decorator] @property diff --git a/backend/app/crud/assessment/__init__.py b/backend/app/crud/assessment/__init__.py index 8e623e3a7..2f5f6f217 100644 --- a/backend/app/crud/assessment/__init__.py +++ b/backend/app/crud/assessment/__init__.py @@ -13,7 +13,7 @@ list_assessment_runs, list_assessments, recompute_assessment_status, - update_assessment_run_l1_stats, + update_assessment_run_prefilter_stats, update_assessment_run_status, update_run_post_processing_config, ) @@ -44,7 +44,7 @@ "list_assessment_datasets", "list_assessments", "recompute_assessment_status", - "update_assessment_run_l1_stats", + "update_assessment_run_prefilter_stats", "update_assessment_run_status", "update_run_post_processing_config", ] diff --git a/backend/app/crud/assessment/batch.py b/backend/app/crud/assessment/batch.py index 531dc038d..5debc3156 100644 --- a/backend/app/crud/assessment/batch.py +++ b/backend/app/crud/assessment/batch.py @@ -322,7 +322,7 @@ def submit_assessment_batch( output_schema = assessment_input.get("output_schema") attachments = [AssessmentAttachment(**a) for a in attachments_raw] - # Use preloaded rows (post-L1 filtered) if provided, else load from dataset. + # Use preloaded rows (post-prefilter filtered) if provided, else load from dataset. if preloaded_rows is not None: rows = preloaded_rows else: diff --git a/backend/app/crud/assessment/core.py b/backend/app/crud/assessment/core.py index d5a184d06..547cc2e31 100644 --- a/backend/app/crud/assessment/core.py +++ b/backend/app/crud/assessment/core.py @@ -248,25 +248,25 @@ def update_assessment_run_status( return run -def update_assessment_run_l1_stats( +def update_assessment_run_prefilter_stats( session: Session, run: AssessmentRun, - l1_object_store_url: str | None = None, - l1_total_rows: int | None = None, - l1_total_passed: int | None = None, - l1_total_rejected: int | None = None, + prefilter_object_store_url: str | None = None, + prefilter_total_rows: int | None = None, + prefilter_total_passed: int | None = None, + prefilter_total_rejected: int | None = None, ) -> AssessmentRun: - """Persist L1 result stats (rows/passed/rejected + S3 URL) on a run.""" + """Persist prefilter result stats (rows/passed/rejected + S3 URL) on a run.""" run.updated_at = now() - if l1_object_store_url is not None: - run.l1_object_store_url = l1_object_store_url - if l1_total_rows is not None: - run.l1_total_rows = l1_total_rows - if l1_total_passed is not None: - run.l1_total_passed = l1_total_passed - if l1_total_rejected is not None: - run.l1_total_rejected = l1_total_rejected + if prefilter_object_store_url is not None: + run.prefilter_object_store_url = prefilter_object_store_url + if prefilter_total_rows is not None: + run.prefilter_total_rows = prefilter_total_rows + if prefilter_total_passed is not None: + run.prefilter_total_passed = prefilter_total_passed + if prefilter_total_rejected is not None: + run.prefilter_total_rejected = prefilter_total_rejected session.add(run) try: @@ -274,16 +274,18 @@ def update_assessment_run_l1_stats( session.refresh(run) except Exception as e: session.rollback() - logger.error(f"[update_assessment_run_l1_stats] Failed: {e}", exc_info=True) + logger.error( + f"[update_assessment_run_prefilter_stats] Failed: {e}", exc_info=True + ) raise return run _ACTIVE_RUN_STATUSES = frozenset( - {"l1_processing", "l2_processing", "processing", "in_progress"} + {"prefilter_processing", "l2_processing", "processing", "in_progress"} ) -_FAILED_RUN_STATUSES = frozenset({"failed", "l1_failed"}) +_FAILED_RUN_STATUSES = frozenset({"failed", "prefilter_failed"}) _COMPLETED_RUN_STATUSES = frozenset({"completed", "completed_with_errors"}) @@ -329,9 +331,9 @@ def build_run_stats(runs: list[AssessmentRun]) -> list[AssessmentRunStat]: total_items=run.total_items, error_message=run.error_message, updated_at=run.updated_at, - l1_total_rows=run.l1_total_rows, - l1_total_passed=run.l1_total_passed, - l1_total_rejected=run.l1_total_rejected, + prefilter_total_rows=run.prefilter_total_rows, + prefilter_total_passed=run.prefilter_total_passed, + prefilter_total_rejected=run.prefilter_total_rejected, ) for run in runs ] diff --git a/backend/app/models/assessment.py b/backend/app/models/assessment.py index b5a1a31f5..8ff468db2 100644 --- a/backend/app/models/assessment.py +++ b/backend/app/models/assessment.py @@ -109,7 +109,7 @@ class AssessmentRun(SQLModel, table=True): default="pending", sa_column_kwargs={ "comment": ( - "Unified pipeline status: pending, l1_processing, l1_failed, " + "Unified pipeline status: pending, prefilter_processing, prefilter_failed, " "l2_processing, completed, completed_with_errors, failed" ) }, @@ -141,25 +141,27 @@ class AssessmentRun(SQLModel, table=True): nullable=True, sa_column_kwargs={"comment": "S3 URL of processed L2 batch results"}, ) - l1_object_store_url: str | None = SQLField( + prefilter_object_store_url: str | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "S3 URL of stored L1 filter results JSON"}, + sa_column_kwargs={"comment": "S3 URL of stored prefilter filter results JSON"}, ) - l1_total_rows: int | None = SQLField( + prefilter_total_rows: int | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "Total rows fed into L1 pipeline"}, + sa_column_kwargs={"comment": "Total rows fed into prefilter pipeline"}, ) - l1_total_passed: int | None = SQLField( + prefilter_total_passed: int | None = SQLField( default=None, nullable=True, sa_column_kwargs={"comment": "Rows that passed topic relevance and went to L2"}, ) - l1_total_rejected: int | None = SQLField( + prefilter_total_rejected: int | None = SQLField( default=None, nullable=True, - sa_column_kwargs={"comment": "Rows rejected by topic relevance, stopped at L1"}, + sa_column_kwargs={ + "comment": "Rows rejected by topic relevance, stopped at prefilter" + }, ) error_message: str | None = SQLField( default=None, @@ -208,9 +210,9 @@ class AssessmentRunStat(BaseModel): total_items: int error_message: str | None = None updated_at: datetime | None = None - l1_total_rows: int | None = None - l1_total_passed: int | None = None - l1_total_rejected: int | None = None + prefilter_total_rows: int | None = None + prefilter_total_passed: int | None = None + prefilter_total_rejected: int | None = None class AssessmentPublic(BaseModel): @@ -250,9 +252,9 @@ class AssessmentRunPublic(BaseModel): "text_columns, attachments, output_schema" ), ) - l1_total_rows: int | None = None - l1_total_passed: int | None = None - l1_total_rejected: int | None = None + prefilter_total_rows: int | None = None + prefilter_total_passed: int | None = None + prefilter_total_rejected: int | None = None post_processing_config: dict[str, Any] | None = None inserted_at: datetime updated_at: datetime @@ -323,11 +325,11 @@ class AssessmentCreate(BaseModel): configs: list[AssessmentConfigRef] = Field( ..., min_length=1, max_length=4, description="Config versions to run" ) - l1_config: dict[str, Any] | None = Field( + prefilter_config: dict[str, Any] | None = Field( None, description=( - "L1 pipeline config. Keys: topic_relevance (columns, prompt), " - "duplicate_detection (columns). Omit to skip L1." + "prefilter pipeline config. Keys: topic_relevance (columns, prompt), " + "duplicate_detection (columns). Omit to skip prefilter." ), ) post_processing_config: dict[str, Any] | None = Field( diff --git a/backend/app/services/assessment/l1/__init__.py b/backend/app/services/assessment/l1/__init__.py deleted file mode 100644 index 66e3a0374..000000000 --- a/backend/app/services/assessment/l1/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from app.services.assessment.l1.pipeline import run_l1_pipeline - -__all__ = ["run_l1_pipeline"] diff --git a/backend/app/services/assessment/prefilter/__init__.py b/backend/app/services/assessment/prefilter/__init__.py new file mode 100644 index 000000000..6cd16dce2 --- /dev/null +++ b/backend/app/services/assessment/prefilter/__init__.py @@ -0,0 +1,3 @@ +from app.services.assessment.prefilter.pipeline import run_prefilter_pipeline + +__all__ = ["run_prefilter_pipeline"] diff --git a/backend/app/services/assessment/l1/duplicate_detection.py b/backend/app/services/assessment/prefilter/duplicate_detection.py similarity index 95% rename from backend/app/services/assessment/l1/duplicate_detection.py rename to backend/app/services/assessment/prefilter/duplicate_detection.py index 608389c1d..bba004457 100644 --- a/backend/app/services/assessment/l1/duplicate_detection.py +++ b/backend/app/services/assessment/prefilter/duplicate_detection.py @@ -1,4 +1,4 @@ -"""Duplicate detection filter for L1 pipeline.""" +"""Duplicate detection filter for prefilter pipeline.""" import json import logging @@ -8,6 +8,8 @@ from google import genai from google.genai import types +from app.core.config import settings + logger = logging.getLogger(__name__) _VAGUE_SYS = """ @@ -78,6 +80,9 @@ def _check_vague( system_instruction=_VAGUE_SYS, response_mime_type="application/json", temperature=0.0, + http_options=types.HttpOptions( + timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS + ), ), ) parsed = json.loads((response.text or "").strip()) @@ -104,6 +109,9 @@ def _call_file_search( ) ], temperature=0.0, + http_options=types.HttpOptions( + timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS + ), ), ) return response.text or "" diff --git a/backend/app/services/assessment/l1/pipeline.py b/backend/app/services/assessment/prefilter/pipeline.py similarity index 67% rename from backend/app/services/assessment/l1/pipeline.py rename to backend/app/services/assessment/prefilter/pipeline.py index 18df91324..131fdde8b 100644 --- a/backend/app/services/assessment/l1/pipeline.py +++ b/backend/app/services/assessment/prefilter/pipeline.py @@ -1,4 +1,4 @@ -"""L1 pipeline orchestrator. +"""prefilter pipeline orchestrator. Runs two filters in series for each row: 1. Topic Relevance (go/no-go) — REJECT stops the row. @@ -18,20 +18,22 @@ from app.core.cloud import get_cloud_storage from app.core.storage_utils import upload_jsonl_to_object_store from app.models.assessment import AssessmentAttachment, AssessmentRun -from app.services.assessment.l1.duplicate_detection import run_duplicate_detection -from app.services.assessment.l1.topic_relevance import run_topic_relevance +from app.services.assessment.prefilter.duplicate_detection import ( + run_duplicate_detection, +) +from app.services.assessment.prefilter.topic_relevance import run_topic_relevance logger = logging.getLogger(__name__) -def _build_l1_result( +def _build_prefilter_result( row_idx: int, tr_result: dict[str, Any] | None, dup_result: dict[str, Any] | None, ) -> dict[str, Any]: return { "row_id": f"row_{row_idx}", - "l1_passed": tr_result["verdict"] if tr_result else True, + "prefilter_passed": tr_result["verdict"] if tr_result else True, "topic_relevance": { "decision": tr_result["decision"], "column_relevance": tr_result.get("column_relevance") or {}, @@ -43,37 +45,40 @@ def _build_l1_result( } -def run_l1_pipeline( +def run_prefilter_pipeline( run: AssessmentRun, rows: list[dict[str, str]], - l1_config: dict[str, Any], + prefilter_config: dict[str, Any], session: Session, organization_id: int, project_id: int, attachments: list[AssessmentAttachment] | None = None, ) -> tuple[list[dict[str, str]], list[int], list[dict[str, Any]]]: - """Run L1 filters on all rows. + """Run prefilter filters on all rows. Args: run: The AssessmentRun record (used for S3 path and DB update). rows: Full dataset rows loaded from object store. - l1_config: User-supplied config with topic_relevance and duplicate_detection keys. + prefilter_config: User-supplied config with topic_relevance and duplicate_detection keys. session: DB session. organization_id: For Gemini credential lookup. project_id: For Gemini credential lookup and S3 storage. Returns: - (passed_rows, passed_indices, all_l1_results) + (passed_rows, passed_indices, all_prefilter_results) passed_rows: subset of rows where topic_relevance verdict=true. passed_indices: original dataset indices of passed_rows (used to preserve row IDs in L2). - all_l1_results: one entry per input row (len == len(rows)). + all_prefilter_results: one entry per input row (len == len(rows)). """ - model = settings.ASSESSMENT_L1_GEMINI_MODEL - workers = settings.ASSESSMENT_L1_CONCURRENT_WORKERS - store_name = settings.ASSESSMENT_L1_DUPLICATE_STORE_NAME + model = settings.ASSESSMENT_PREFILTER_GEMINI_MODEL + workers = settings.ASSESSMENT_PREFILTER_CONCURRENT_WORKERS + store_name = settings.ASSESSMENT_PREFILTER_DUPLICATE_STORE_NAME + # Future wait bound: the per-request HTTP timeout plus a small margin so a + # hung Gemini call surfaces as a future error instead of blocking forever. + future_timeout = settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS / 1000 + 30 - tr_config = l1_config.get("topic_relevance") or {} - dup_config = l1_config.get("duplicate_detection") or {} + tr_config = prefilter_config.get("topic_relevance") or {} + dup_config = prefilter_config.get("duplicate_detection") or {} tr_columns: list[str] = tr_config.get("columns") or [] tr_prompt: str = tr_config.get("prompt") or "" @@ -91,7 +96,7 @@ def run_l1_pipeline( if not tr_enabled and not dup_enabled: logger.warning( - "[run_l1_pipeline] run_id=%s — no L1 filters configured, skipping L1", + "[run_prefilter_pipeline] run_id=%s — no prefilter filters configured, skipping prefilter", run.id, ) return rows, list(range(len(rows))), [] @@ -103,7 +108,7 @@ def run_l1_pipeline( ).client logger.info( - "[run_l1_pipeline] run_id=%s | rows=%s | model=%s | workers=%s | tr=%s | dup=%s", + "[run_prefilter_pipeline] run_id=%s | rows=%s | model=%s | workers=%s | tr=%s | dup=%s", run.id, len(rows), model, @@ -136,10 +141,10 @@ def run_l1_pipeline( for fut in as_completed(futs): idx = futs[fut] try: - tr_results[idx] = fut.result() + tr_results[idx] = fut.result(timeout=future_timeout) except Exception as exc: logger.warning( - "[run_l1_pipeline] TR future error row_%s | %s", idx, exc + "[run_prefilter_pipeline] TR future error row_%s | %s", idx, exc ) tr_results[idx] = { "row_id": f"row_{idx}", @@ -156,7 +161,7 @@ def run_l1_pipeline( rejected_count = len(rows) - len(passed_indices) logger.info( - "[run_l1_pipeline] run_id=%s | TR done | passed=%s | rejected=%s", + "[run_prefilter_pipeline] run_id=%s | TR done | passed=%s | rejected=%s", run.id, len(passed_indices), rejected_count, @@ -180,10 +185,12 @@ def run_l1_pipeline( for fut in as_completed(futs): idx = futs[fut] try: - dup_results[idx] = fut.result() + dup_results[idx] = fut.result(timeout=future_timeout) except Exception as exc: logger.warning( - "[run_l1_pipeline] DUP future error row_%s | %s", idx, exc + "[run_prefilter_pipeline] DUP future error row_%s | %s", + idx, + exc, ) dup_results[idx] = { "row_id": f"row_{idx}", @@ -194,45 +201,45 @@ def run_l1_pipeline( "reason": str(exc)[:200], } - all_l1_results: list[dict[str, Any]] = [ - _build_l1_result(idx, tr_results[idx], dup_results.get(idx)) + all_prefilter_results: list[dict[str, Any]] = [ + _build_prefilter_result(idx, tr_results[idx], dup_results.get(idx)) for idx in range(len(rows)) ] - l1_object_store_url: str | None = None + prefilter_object_store_url: str | None = None try: storage = get_cloud_storage(session=session, project_id=project_id) - l1_object_store_url = upload_jsonl_to_object_store( + prefilter_object_store_url = upload_jsonl_to_object_store( storage=storage, - results=all_l1_results, - filename="l1_results.json", - subdirectory=f"assessment/run-{run.id}/l1", + results=all_prefilter_results, + filename="prefilter_results.json", + subdirectory=f"assessment/run-{run.id}/prefilter", format="json", ) logger.info( - "[run_l1_pipeline] run_id=%s | L1 results uploaded to %s", + "[run_prefilter_pipeline] run_id=%s | prefilter results uploaded to %s", run.id, - l1_object_store_url, + prefilter_object_store_url, ) except Exception as exc: logger.error( - "[run_l1_pipeline] run_id=%s | S3 upload failed | %s", + "[run_prefilter_pipeline] run_id=%s | S3 upload failed | %s", run.id, exc, exc_info=True, ) - from app.crud.assessment.core import update_assessment_run_l1_stats + from app.crud.assessment.core import update_assessment_run_prefilter_stats - update_assessment_run_l1_stats( + update_assessment_run_prefilter_stats( session=session, run=run, - l1_object_store_url=l1_object_store_url, - l1_total_rows=len(rows), - l1_total_passed=len(passed_indices), - l1_total_rejected=rejected_count, + prefilter_object_store_url=prefilter_object_store_url, + prefilter_total_rows=len(rows), + prefilter_total_passed=len(passed_indices), + prefilter_total_rejected=rejected_count, ) sorted_passed_indices = sorted(passed_indices) passed_rows = [rows[idx] for idx in sorted_passed_indices] - return passed_rows, sorted_passed_indices, all_l1_results + return passed_rows, sorted_passed_indices, all_prefilter_results diff --git a/backend/app/services/assessment/l1/topic_relevance.py b/backend/app/services/assessment/prefilter/topic_relevance.py similarity index 94% rename from backend/app/services/assessment/l1/topic_relevance.py rename to backend/app/services/assessment/prefilter/topic_relevance.py index c1894c04e..053547ab7 100644 --- a/backend/app/services/assessment/l1/topic_relevance.py +++ b/backend/app/services/assessment/prefilter/topic_relevance.py @@ -1,4 +1,4 @@ -"""Topic relevance filter for L1 pipeline. +"""Topic relevance filter for prefilter pipeline. """ import json @@ -8,6 +8,7 @@ from google import genai from google.genai import types +from app.core.config import settings from app.models.assessment import AssessmentAttachment from app.services.assessment.utils.attachments import build_gemini_attachment_parts @@ -89,6 +90,9 @@ def run_topic_relevance( response_mime_type="application/json", response_schema=output_schema, temperature=0.0, + http_options=types.HttpOptions( + timeout=settings.ASSESSMENT_PREFILTER_REQUEST_TIMEOUT_MS + ), ), ) raw = (response.text or "").strip() diff --git a/backend/app/services/assessment/service.py b/backend/app/services/assessment/service.py index cabe2bb4c..b2e2cea05 100644 --- a/backend/app/services/assessment/service.py +++ b/backend/app/services/assessment/service.py @@ -79,7 +79,7 @@ def _build_retry_request( attachments=[AssessmentAttachment.model_validate(item) for item in attachments], output_schema=assessment_input.get("output_schema"), configs=configs, - l1_config=assessment_input.get("l1_config"), + prefilter_config=assessment_input.get("prefilter_config"), post_processing_config=assessment_input.get("post_processing_config"), ) @@ -93,7 +93,7 @@ def start_assessment( """Validate, create Assessment + AssessmentRun records, dispatch Celery tasks. Each run is created with status='pending' and handed off to a Celery worker - that runs L1 filtering then submits the L2 batch. + that runs prefilter filtering then submits the L2 batch. """ from app.celery.tasks.job_execution import run_assessment_run @@ -120,8 +120,8 @@ def start_assessment( } if request.output_schema: assessment_input["output_schema"] = request.output_schema - if request.l1_config: - assessment_input["l1_config"] = request.l1_config + if request.prefilter_config: + assessment_input["prefilter_config"] = request.prefilter_config if request.post_processing_config: assessment_input["post_processing_config"] = request.post_processing_config diff --git a/backend/app/services/assessment/tasks.py b/backend/app/services/assessment/tasks.py index 295c55ad2..909a89a25 100644 --- a/backend/app/services/assessment/tasks.py +++ b/backend/app/services/assessment/tasks.py @@ -1,4 +1,4 @@ -"""Celery task logic for running a single assessment run (L1 → L2 batch submit).""" +"""Celery task logic for running a single assessment run (prefilter → L2 batch submit).""" import logging @@ -19,7 +19,7 @@ AssessmentRun, ) from app.models.config.config import ConfigTag -from app.services.assessment.l1 import run_l1_pipeline +from app.services.assessment.prefilter import run_prefilter_pipeline logger = logging.getLogger(__name__) @@ -29,12 +29,12 @@ def execute_assessment_run( organization_id: int, project_id: int, ) -> None: - """Run L1 filtering then submit L2 batch for one AssessmentRun. + """Run prefilter filtering then submit L2 batch for one AssessmentRun. Status transitions: - pending → l1_processing → l1_failed (stop) + pending → prefilter_processing → prefilter_failed (stop) → l2_processing → (cron handles rest) - pending → l2_processing (when no l1_config) + pending → l2_processing (when no prefilter_config) """ with Session(engine) as session: run = session.get(AssessmentRun, run_id) @@ -116,19 +116,19 @@ def execute_assessment_run( recompute_assessment_status(session=session, assessment_id=assessment.id) return - # L1 pipeline + # prefilter pipeline rows_for_l2 = all_rows row_indices_for_l2: list[int] | None = None - l1_config = assessment_input.get("l1_config") - if l1_config: + prefilter_config = assessment_input.get("prefilter_config") + if prefilter_config: update_assessment_run_status( - session=session, run=run, status="l1_processing" + session=session, run=run, status="prefilter_processing" ) try: - rows_for_l2, row_indices_for_l2, _ = run_l1_pipeline( + rows_for_l2, row_indices_for_l2, _ = run_prefilter_pipeline( run=run, rows=all_rows, - l1_config=l1_config, + prefilter_config=prefilter_config, session=session, organization_id=organization_id, project_id=project_id, @@ -138,28 +138,28 @@ def execute_assessment_run( ], ) logger.info( - "[execute_assessment_run] L1 done | run_id=%s | rows_to_l2=%s / %s", + "[execute_assessment_run] prefilter done | run_id=%s | rows_to_l2=%s / %s", run_id, len(rows_for_l2), len(all_rows), ) - except Exception as l1_exc: + except Exception as prefilter_exc: logger.error( - "[execute_assessment_run] L1 failed run_id=%s | %s", + "[execute_assessment_run] prefilter failed run_id=%s | %s", run_id, - l1_exc, + prefilter_exc, exc_info=True, ) update_assessment_run_status( session=session, run=run, - status="l1_failed", - error_message=f"L1 pipeline failed: {l1_exc}", + status="prefilter_failed", + error_message=f"prefilter pipeline failed: {prefilter_exc}", ) recompute_assessment_status( session=session, assessment_id=assessment.id ) - return # L2 does not run when L1 fails + return # L2 does not run when prefilter fails # L2 batch submit try: diff --git a/backend/app/services/assessment/utils/attachments.py b/backend/app/services/assessment/utils/attachments.py index 3622f9bce..87ca3aba7 100644 --- a/backend/app/services/assessment/utils/attachments.py +++ b/backend/app/services/assessment/utils/attachments.py @@ -9,11 +9,12 @@ import logging import re from typing import Any -from urllib.parse import urlparse +from urllib.parse import urljoin, urlparse import requests from app.models.assessment import AssessmentAttachment +from app.utils import validate_callback_url logger = logging.getLogger(__name__) @@ -177,33 +178,43 @@ def _type_from_content_type(content_type: str | None) -> str | None: return None +_PROBE_MAX_REDIRECTS = 3 + + def _probe_url_type(url: str, num_bytes: int = 16) -> str | None: """Probe a remote URL's type: ranged byte sniff first, Content-Type fallback. - - Reads only the first few bytes (does not download the whole file). Drive - share URLs are routed through the download endpoint so the real file bytes - are read instead of an HTML share page. - """ + Handles Google Drive URLs with the same logic as to_direct_attachment_url, since""" file_id = _drive_file_id(url) - probe_url = ( + current = ( f"https://drive.google.com/uc?export=download&id={file_id}" if file_id else url ) try: - with requests.get( - probe_url, - headers={"Range": f"bytes=0-{num_bytes - 1}"}, - timeout=10, - stream=True, - allow_redirects=True, - ) as resp: - resp.raise_for_status() - for chunk in resp.iter_content(chunk_size=num_bytes): - magic_type = _type_from_magic(chunk) - if magic_type: - return magic_type - break - return _type_from_content_type(resp.headers.get("Content-Type")) + for _ in range(_PROBE_MAX_REDIRECTS + 1): + validate_callback_url(current) + with requests.get( + current, + headers={"Range": f"bytes=0-{num_bytes - 1}"}, + timeout=10, + stream=True, + allow_redirects=False, + ) as resp: + location = resp.headers.get("Location") + if resp.is_redirect and location: + current = urljoin(current, location) + continue + resp.raise_for_status() + for chunk in resp.iter_content(chunk_size=num_bytes): + magic_type = _type_from_magic(chunk) + if magic_type: + return magic_type + break + return _type_from_content_type(resp.headers.get("Content-Type")) + logger.warning(f"[_probe_url_type] Too many redirects probing {url}") + return None + except ValueError as e: + logger.warning(f"[_probe_url_type] Blocked unsafe probe URL {url}: {e}") + return None except requests.RequestException as e: logger.warning(f"[_probe_url_type] Probe failed for {url}: {e}") return None @@ -312,7 +323,7 @@ def build_gemini_attachment_parts( """Convert one dataset cell into one or more Gemini content parts. Mirrors the per-item type detection used for the L2 batch so the same - image/pdf routing applies to L1 (topic relevance) calls. + image/pdf routing applies to prefilter (topic relevance) calls. """ value = value.strip() if not value: diff --git a/backend/app/services/assessment/utils/export.py b/backend/app/services/assessment/utils/export.py index 86d9186b0..39fa7691c 100644 --- a/backend/app/services/assessment/utils/export.py +++ b/backend/app/services/assessment/utils/export.py @@ -22,7 +22,7 @@ from app.services.assessment.utils.parsing import parse_stored_results, usage_totals from app.utils import APIResponse -_L1_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] +_PREFILTER_JSON_COLUMNS = ["topic_relevance", "duplicate_detection"] logger = logging.getLogger(__name__) @@ -36,23 +36,23 @@ def _load_dataset_rows( return load_dataset_rows(session, dataset) -def _load_l1_results( +def _load_prefilter_results( session: Session, run: AssessmentRun, assessment: Assessment, ) -> dict[str, dict[str, Any]]: - """Load L1 results from object store, keyed by row_id. Returns {} if unavailable.""" - if not run.l1_object_store_url: + """Load prefilter results from object store, keyed by row_id. Returns {} if unavailable.""" + if not run.prefilter_object_store_url: return {} try: storage = get_cloud_storage(session, project_id=assessment.project_id) - body = storage.stream(run.l1_object_store_url) + body = storage.stream(run.prefilter_object_store_url) raw = body.read().decode("utf-8") results: list[dict[str, Any]] = json.loads(raw) return {str(item["row_id"]): item for item in results if "row_id" in item} except Exception as exc: logger.warning( - "[_load_l1_results] Failed to load L1 results for run id=%s: %s", + "[_load_prefilter_results] Failed to load prefilter results for run id=%s: %s", run.id, exc, ) @@ -163,32 +163,34 @@ def _expand_output_columns( """ row_payload, input_col_names = _expand_input_columns(row_payload) - json_expand_cols = {"output", "input_data"} | set(_L1_JSON_COLUMNS) + json_expand_cols = {"output", "input_data"} | set(_PREFILTER_JSON_COLUMNS) base_fields = [ field for field in AssessmentExportRow.model_fields.keys() if field not in json_expand_cols ] - # L1 columns are prefixed with their parent name to avoid key collisions + # prefilter columns are prefixed with their parent name to avoid key collisions parsed_cols: dict[str, list[dict[str, Any] | None]] = { - col: [] for col in ["output"] + _L1_JSON_COLUMNS + col: [] for col in ["output"] + _PREFILTER_JSON_COLUMNS + } + col_keys: dict[str, list[str]] = { + col: [] for col in ["output"] + _PREFILTER_JSON_COLUMNS } - col_keys: dict[str, list[str]] = {col: [] for col in ["output"] + _L1_JSON_COLUMNS} col_seen: dict[str, dict[str, None]] = { - col: {} for col in ["output"] + _L1_JSON_COLUMNS + col: {} for col in ["output"] + _PREFILTER_JSON_COLUMNS } has_unparsed_output = False for row in row_payload: - for col in ["output"] + _L1_JSON_COLUMNS: + for col in ["output"] + _PREFILTER_JSON_COLUMNS: parsed = _parse_json_col(row.get(col)) if parsed is None and col == "output" and row.get(col) is not None: has_unparsed_output = True parsed_cols[col].append(parsed) if parsed: for k in parsed: - prefixed = f"{col}_{k}" if col in _L1_JSON_COLUMNS else k + prefixed = f"{col}_{k}" if col in _PREFILTER_JSON_COLUMNS else k if prefixed not in col_seen[col]: col_seen[col][prefixed] = None col_keys[col].append(prefixed) @@ -196,7 +198,7 @@ def _expand_output_columns( def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: if not parsed: return {} - if col in _L1_JSON_COLUMNS: + if col in _PREFILTER_JSON_COLUMNS: return {f"{col}_{k}": v for k, v in parsed.items()} return parsed @@ -204,7 +206,7 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: expanded: list[dict[str, Any]] = [] for i, row in enumerate(row_payload): new_row = {k: v for k, v in row.items() if k not in json_expand_cols} - for col in ["output"] + _L1_JSON_COLUMNS: + for col in ["output"] + _PREFILTER_JSON_COLUMNS: parsed = parsed_cols[col][i] keys = col_keys[col] prefixed_vals = _get_prefixed(parsed, col) @@ -218,22 +220,22 @@ def _get_prefixed(parsed: dict[str, Any] | None, col: str) -> dict[str, Any]: new_row["output_raw"] = row.get("output") expanded.append(new_row) - l1_keys = col_keys["topic_relevance"] + col_keys["duplicate_detection"] + prefilter_keys = col_keys["topic_relevance"] + col_keys["duplicate_detection"] output_keys = col_keys["output"] - all_output_keys = l1_keys + output_keys + all_output_keys = prefilter_keys + output_keys if not all_output_keys: fieldnames = input_col_names + list(AssessmentExportRow.model_fields.keys()) fieldnames = [f for f in fieldnames if f != "input_data"] return row_payload, fieldnames, input_col_names, [], [] - fieldnames = input_col_names + l1_keys + output_keys + base_fields + fieldnames = input_col_names + prefilter_keys + output_keys + base_fields if has_unparsed_output: fieldnames.insert( - len(input_col_names) + len(l1_keys) + len(output_keys), "output_raw" + len(input_col_names) + len(prefilter_keys) + len(output_keys), "output_raw" ) - return expanded, fieldnames, input_col_names, l1_keys, output_keys + return expanded, fieldnames, input_col_names, prefilter_keys, output_keys def serialize_export_rows( @@ -258,7 +260,7 @@ def serialize_export_rows( expanded, fieldnames, input_col_names, - l1_keys, + prefilter_keys, output_keys, ) = _expand_output_columns(row_payload) expanded = apply_post_processing(expanded, post_processing_config) @@ -288,8 +290,8 @@ def serialize_export_rows( detail="XLSX export requires pandas/openpyxl support in the backend runtime", ) from exc - # Explicit ordering: inputs → L1 → L2 → computed columns - excel_fields = input_col_names + l1_keys + output_keys + computed_names + # Explicit ordering: inputs → prefilter → L2 → computed columns + excel_fields = input_col_names + prefilter_keys + output_keys + computed_names if not excel_fields: excel_fields = output_keys or ["output"] @@ -431,15 +433,15 @@ def _load_dataset_rows_for_run( return [] -def _extract_l1_json_columns( - l1_item: dict[str, Any] | None, +def _extract_prefilter_json_columns( + prefilter_item: dict[str, Any] | None, ) -> dict[str, Any]: """Return topic_relevance and duplicate_detection as JSON strings for export expansion.""" - if not l1_item: + if not prefilter_item: return {"topic_relevance": None, "duplicate_detection": None} - tr = l1_item.get("topic_relevance") - dup = l1_item.get("duplicate_detection") + tr = prefilter_item.get("topic_relevance") + dup = prefilter_item.get("duplicate_detection") tr_flat: dict[str, Any] | None = None if tr: @@ -468,10 +470,10 @@ def load_export_rows_for_run( ) -> list[AssessmentExportRow]: """Load flattened export rows for a single child assessment run. - When L1 results exist, ALL dataset rows are included in output. - L1-rejected rows have L1 columns filled and L2 columns empty. - L1-passed rows have all columns filled. - Without L1, behaviour is unchanged (only L2 result rows returned). + When prefilter results exist, ALL dataset rows are included in output. + prefilter-rejected rows have prefilter columns filled and L2 columns empty. + prefilter-passed rows have all columns filled. + Without prefilter, behaviour is unchanged (only L2 result rows returned). """ if assessment is None: assessment = session.get(Assessment, run.assessment_id) @@ -486,8 +488,8 @@ def load_export_rows_for_run( dataset_name = dataset.name if dataset else None dataset_rows = _load_dataset_rows_for_run(session, run, assessment) - # Load L1 results (empty dict if no L1 was run) - l1_by_row_id = _load_l1_results(session, run, assessment) + # Load prefilter results (empty dict if no prefilter was run) + prefilter_by_row_id = _load_prefilter_results(session, run, assessment) # Load L2 results (may be None if batch not complete) l2_by_row_id: dict[str, dict[str, Any]] = {} @@ -504,24 +506,24 @@ def load_export_rows_for_run( if "row_id" in item } - has_l1 = bool(l1_by_row_id) + has_prefilter = bool(prefilter_by_row_id) - if has_l1 and dataset_rows: + if has_prefilter and dataset_rows: # All rows in output — build from full dataset export_rows: list[AssessmentExportRow] = [] for row_idx, input_data in enumerate(dataset_rows): row_id_str = f"row_{row_idx}" - l1_item = l1_by_row_id.get(row_id_str) - l1_cols = _extract_l1_json_columns(l1_item) + prefilter_item = prefilter_by_row_id.get(row_id_str) + prefilter_cols = _extract_prefilter_json_columns(prefilter_item) l2_item = l2_by_row_id.get(row_id_str) input_tokens, output_tokens, total_tokens = usage_totals( l2_item.get("usage") if l2_item else None ) - l1_passed = (l1_item or {}).get("l1_passed", True) + prefilter_passed = (prefilter_item or {}).get("prefilter_passed", True) result_status = ( - "l1_rejected" - if not l1_passed + "prefilter_rejected" + if not prefilter_passed else ("failed" if l2_item and l2_item.get("error") else "passed") ) @@ -539,8 +541,8 @@ def load_export_rows_for_run( row_id=row_id_str, result_status=result_status, input_data=input_data, - topic_relevance=l1_cols.get("topic_relevance"), - duplicate_detection=l1_cols.get("duplicate_detection"), + topic_relevance=prefilter_cols.get("topic_relevance"), + duplicate_detection=prefilter_cols.get("duplicate_detection"), output=l2_item.get("output") if l2_item else None, error=l2_item.get("error") if l2_item else None, response_id=l2_item.get("response_id") if l2_item else None, @@ -552,7 +554,7 @@ def load_export_rows_for_run( ) return export_rows - # No L1 — original behaviour: only L2 result rows + # No prefilter — original behaviour: only L2 result rows if not run.batch_job_id: logger.warning( "[load_export_rows_for_run] No batch_job_id for run id=%s", run.id diff --git a/backend/app/tests/assessment/test_batch.py b/backend/app/tests/assessment/test_batch.py index aa0fce1a0..38373774c 100644 --- a/backend/app/tests/assessment/test_batch.py +++ b/backend/app/tests/assessment/test_batch.py @@ -470,9 +470,12 @@ def test_url_no_extension_probes_bytes(self) -> None: resp = MagicMock() resp.__enter__ = MagicMock(return_value=resp) resp.__exit__ = MagicMock(return_value=False) + resp.is_redirect = False resp.raise_for_status = MagicMock() resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", return_value=resp, ) as mock_get: @@ -485,10 +488,13 @@ def test_url_probe_uses_content_type_when_no_magic(self) -> None: resp = MagicMock() resp.__enter__ = MagicMock(return_value=resp) resp.__exit__ = MagicMock(return_value=False) + resp.is_redirect = False resp.raise_for_status = MagicMock() resp.iter_content = MagicMock(return_value=iter([b"\x00\x01\x02\x03"])) resp.headers = {"Content-Type": "application/pdf; charset=binary"} with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", return_value=resp, ): @@ -499,20 +505,60 @@ def test_url_probe_failure_falls_back(self) -> None: url = "https://example.com/file" with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", side_effect=_requests.RequestException("boom"), ): assert detect_item_type(url, "url", "image", {}) == "image" + def test_url_probe_follows_validated_redirect(self) -> None: + """A redirect hop is followed and re-validated before the next request.""" + url = "https://drive.google.com/file/d/RID/view" + redirect = MagicMock() + redirect.__enter__ = MagicMock(return_value=redirect) + redirect.__exit__ = MagicMock(return_value=False) + redirect.is_redirect = True + redirect.headers = {"Location": "https://files.example.com/real.pdf"} + final = MagicMock() + final.__enter__ = MagicMock(return_value=final) + final.__exit__ = MagicMock(return_value=False) + final.is_redirect = False + final.raise_for_status = MagicMock() + final.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) + with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ) as validate, patch( + "app.services.assessment.utils.attachments.requests.get", + side_effect=[redirect, final], + ) as mock_get: + assert detect_item_type(url, "url", "image", {}) == "pdf" + # Both the initial and redirected URLs were validated and fetched. + assert validate.call_count == 2 + assert mock_get.call_count == 2 + + def test_url_probe_blocked_by_ssrf_falls_back(self) -> None: + url = "https://internal.host/file" + with patch( + "app.services.assessment.utils.attachments.validate_callback_url", + side_effect=ValueError("private IP"), + ), patch("app.services.assessment.utils.attachments.requests.get") as mock_get: + # SSRF guard blocks the probe -> falls back to declared type. + assert detect_item_type(url, "url", "pdf", {}) == "pdf" + mock_get.assert_not_called() + def test_cache_skips_second_probe(self) -> None: url = "https://drive.google.com/file/d/XYZ/view" cache: dict[str, str] = {} resp = MagicMock() resp.__enter__ = MagicMock(return_value=resp) resp.__exit__ = MagicMock(return_value=False) + resp.is_redirect = False resp.raise_for_status = MagicMock() resp.iter_content = MagicMock(return_value=iter([b"%PDF-1.7"])) with patch( + "app.services.assessment.utils.attachments.validate_callback_url" + ), patch( "app.services.assessment.utils.attachments.requests.get", return_value=resp, ) as mock_get: diff --git a/backend/app/tests/assessment/test_crud.py b/backend/app/tests/assessment/test_crud.py index 1cf30249e..e2f44a21a 100644 --- a/backend/app/tests/assessment/test_crud.py +++ b/backend/app/tests/assessment/test_crud.py @@ -27,7 +27,7 @@ update_assessment_run_status, update_run_post_processing_config, ) -from app.crud.assessment.core import update_assessment_run_l1_stats +from app.crud.assessment.core import update_assessment_run_prefilter_stats from app.models.stt_evaluation import EvaluationType @@ -234,9 +234,9 @@ def test_build_run_stats(self) -> None: total_items=2, error_message=None, updated_at=datetime(2024, 1, 1), - l1_total_rows=None, - l1_total_passed=None, - l1_total_rejected=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, ), ] stats = build_run_stats(runs) @@ -343,23 +343,23 @@ def test_sets_stats_fields(self) -> None: run = SimpleNamespace( id=8, updated_at=None, - l1_object_store_url=None, - l1_total_rows=None, - l1_total_passed=None, - l1_total_rejected=None, + prefilter_object_store_url=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, ) - out = update_assessment_run_l1_stats( + out = update_assessment_run_prefilter_stats( session=session, run=run, - l1_object_store_url="s3://x", - l1_total_rows=10, - l1_total_passed=7, - l1_total_rejected=3, + prefilter_object_store_url="s3://x", + prefilter_total_rows=10, + prefilter_total_passed=7, + prefilter_total_rejected=3, ) - assert out.l1_object_store_url == "s3://x" - assert out.l1_total_rows == 10 - assert out.l1_total_passed == 7 - assert out.l1_total_rejected == 3 + assert out.prefilter_object_store_url == "s3://x" + assert out.prefilter_total_rows == 10 + assert out.prefilter_total_passed == 7 + assert out.prefilter_total_rejected == 3 session.commit.assert_called_once() def test_commit_failure_rolls_back(self) -> None: @@ -368,11 +368,13 @@ def test_commit_failure_rolls_back(self) -> None: run = SimpleNamespace( id=9, updated_at=None, - l1_object_store_url=None, - l1_total_rows=None, - l1_total_passed=None, - l1_total_rejected=None, + prefilter_object_store_url=None, + prefilter_total_rows=None, + prefilter_total_passed=None, + prefilter_total_rejected=None, ) with pytest.raises(RuntimeError): - update_assessment_run_l1_stats(session=session, run=run, l1_total_rows=1) + update_assessment_run_prefilter_stats( + session=session, run=run, prefilter_total_rows=1 + ) session.rollback.assert_called_once() diff --git a/backend/app/tests/assessment/test_duplicate_detection.py b/backend/app/tests/assessment/test_duplicate_detection.py index 24d5ac951..5d363f896 100644 --- a/backend/app/tests/assessment/test_duplicate_detection.py +++ b/backend/app/tests/assessment/test_duplicate_detection.py @@ -1,9 +1,9 @@ -"""Tests for L1 duplicate detection.""" +"""Tests for prefilter duplicate detection.""" import json from unittest.mock import MagicMock -from app.services.assessment.l1.duplicate_detection import ( +from app.services.assessment.prefilter.duplicate_detection import ( _build_combined, _parse_verdict, run_duplicate_detection, diff --git a/backend/app/tests/assessment/test_pipeline.py b/backend/app/tests/assessment/test_pipeline.py index faa64693e..d74841650 100644 --- a/backend/app/tests/assessment/test_pipeline.py +++ b/backend/app/tests/assessment/test_pipeline.py @@ -1,9 +1,9 @@ -"""Tests for the L1 pipeline orchestrator.""" +"""Tests for the prefilter pipeline orchestrator.""" from contextlib import ExitStack from unittest.mock import MagicMock, patch -from app.services.assessment.l1.pipeline import run_l1_pipeline +from app.services.assessment.prefilter.pipeline import run_prefilter_pipeline def _run() -> MagicMock: @@ -27,32 +27,32 @@ def _patches(stack: ExitStack, *, tr_side=None, dup_return=None): client = MagicMock() stack.enter_context( patch( - "app.services.assessment.l1.pipeline.GeminiClient.from_credentials", + "app.services.assessment.prefilter.pipeline.GeminiClient.from_credentials", return_value=MagicMock(client=client), ) ) stack.enter_context( patch( - "app.services.assessment.l1.pipeline.get_cloud_storage", + "app.services.assessment.prefilter.pipeline.get_cloud_storage", return_value=MagicMock(), ) ) stack.enter_context( patch( - "app.services.assessment.l1.pipeline.upload_jsonl_to_object_store", - return_value="s3://l1.json", + "app.services.assessment.prefilter.pipeline.upload_jsonl_to_object_store", + return_value="s3://prefilter.json", ) ) stack.enter_context( - patch("app.crud.assessment.core.update_assessment_run_l1_stats") + patch("app.crud.assessment.core.update_assessment_run_prefilter_stats") ) tr_mock = stack.enter_context( - patch("app.services.assessment.l1.pipeline.run_topic_relevance") + patch("app.services.assessment.prefilter.pipeline.run_topic_relevance") ) if tr_side is not None: tr_mock.side_effect = tr_side dup_mock = stack.enter_context( - patch("app.services.assessment.l1.pipeline.run_duplicate_detection") + patch("app.services.assessment.prefilter.pipeline.run_duplicate_detection") ) if dup_return is not None: dup_mock.return_value = dup_return @@ -62,10 +62,10 @@ def _patches(stack: ExitStack, *, tr_side=None, dup_return=None): class TestRunL1Pipeline: def test_no_filters_configured_passthrough(self) -> None: rows = [{"Problem": "a"}, {"Problem": "b"}] - passed, indices, results = run_l1_pipeline( + passed, indices, results = run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={}, + prefilter_config={}, session=MagicMock(), organization_id=1, project_id=1, @@ -80,10 +80,10 @@ def test_topic_relevance_filters_rejected_rows(self) -> None: side = [_tr(True), _tr(False, "REJECT"), _tr(True)] with ExitStack() as stack: _patches(stack, tr_side=side) - passed, indices, results = run_l1_pipeline( + passed, indices, results = run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={ + prefilter_config={ "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"} }, session=MagicMock(), @@ -93,7 +93,7 @@ def test_topic_relevance_filters_rejected_rows(self) -> None: assert indices == [0, 2] assert [r["Problem"] for r in passed] == ["keep", "keep2"] assert len(results) == 3 - assert results[1]["l1_passed"] is False + assert results[1]["prefilter_passed"] is False def test_duplicate_detection_runs_on_passed_rows(self) -> None: rows = [{"Problem": "a", "Solution": "b"}] @@ -107,10 +107,10 @@ def test_duplicate_detection_runs_on_passed_rows(self) -> None: } with ExitStack() as stack: tr_mock, dup_mock = _patches(stack, tr_side=[_tr(True)], dup_return=dup) - _, _, results = run_l1_pipeline( + _, _, results = run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={ + prefilter_config={ "topic_relevance": {"columns": ["Problem"], "prompt": "rubric"}, "duplicate_detection": {"columns": ["Problem", "Solution"]}, }, @@ -131,10 +131,10 @@ def test_attachment_columns_filtered_to_selection(self) -> None: ] with ExitStack() as stack: tr_mock, _ = _patches(stack, tr_side=[_tr(True)]) - run_l1_pipeline( + run_prefilter_pipeline( run=_run(), rows=rows, - l1_config={ + prefilter_config={ "topic_relevance": { "columns": ["Problem"], "prompt": "rubric", diff --git a/backend/app/tests/assessment/test_topic_relevance.py b/backend/app/tests/assessment/test_topic_relevance.py index ad52c2306..064d4476b 100644 --- a/backend/app/tests/assessment/test_topic_relevance.py +++ b/backend/app/tests/assessment/test_topic_relevance.py @@ -1,10 +1,10 @@ -"""Tests for L1 topic relevance attachment handling.""" +"""Tests for prefilter topic relevance attachment handling.""" import json from unittest.mock import MagicMock from app.models.assessment import AssessmentAttachment -from app.services.assessment.l1.topic_relevance import run_topic_relevance +from app.services.assessment.prefilter.topic_relevance import run_topic_relevance def _client_returning(decision: str) -> MagicMock: