Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 7 additions & 0 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,13 @@ You should see an answer followed by a `Sources:` block listing the URLs used.

## Using it

While inside the interactive CLI (`ask>`), you can use the following commands to control the session and view metadata:
| Command | Description |
| :--- | :--- |
| `:stats` | Quick check on how many sources and chunks are loaded in the DB. |
| `:verbose` | Toggles verbose mode. On every query, it'll dump the retrieved chunks (URLs, match scores, and snippets) right before the response. |
| `exit` or `quit` | Safely terminates the interactive session and returns to your terminal shell. (You can also use `Ctrl-D`). |

A grounded answer looks like this:

```
Expand Down
45 changes: 43 additions & 2 deletions src/apps/dev_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
from src.config.logger import get_logger
from src.infrastructure.db import async_session_factory
from src.infrastructure.db.repository import Repository
from src.retrieval.services import retrieval_service

log = get_logger(__name__)

Expand All @@ -17,15 +18,43 @@ async def _check_db() -> None:
)


async def _print_db_status():
async with async_session_factory() as session:
count_sources, count_chunks = await Repository.get_source_and_chunk_counts(session)
if count_chunks == 0:
print(
"WARNING: The database has no chunks. Run `make ingest` first, "
"or your questions will all be answered with 'I don't know'.\n"
)
print(f"{count_sources} sources, {count_chunks} chunks loaded")


async def _repl() -> None:
await _check_db()
await _print_db_status()
verbose = False # flag for :verbose

print("cs-assistant dev CLI. Type 'exit' or Ctrl-D to quit.\n")
print("Type ':stats' or ':verbose' for cmds.\n")

while True:
try:
question = input("ask> ").strip()
except (EOFError, KeyboardInterrupt):
print("\nbye")
return

# :stats cmd [THIS NEEDS FIXING/REFACTORING]
if question.lower() in {":stats"}:
await _print_db_status()
continue

# :verbose cmd
if question.lower() in {":verbose"}:
verbose = not verbose
print(f"Verbose mode: {'ON' if verbose else 'OFF'}")
continue

# exit/quit cmd
if question.lower() in {"exit", "quit"}:
return
if not question:
Expand All @@ -37,11 +66,23 @@ async def _repl() -> None:
print(f"\nError: {e}\n")
continue

# printing out chunk content (verbose mode)
if verbose:
retrieved_chunks = await retrieval_service.get_relevant_chunks(question)
for chunk_item in retrieved_chunks:
source_url = chunk_item.chunk.source_url
similarity_score = chunk_item.score
snippet = chunk_item.chunk.content[:250] + "[...]"
print(f"URL: {source_url}")
print(f"Similarity score: {similarity_score}")
print(f"Content snippet: {snippet}")
print("-" * 60)

print(f"\n{answer.text}\n")
if answer.sources:
print("Sources:")
for source in answer.sources:
print(f" - {source.url}")
print(f"{source.url}")
print()


Expand Down
18 changes: 18 additions & 0 deletions src/infrastructure/db/repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,24 @@ async def has_chunks(session: AsyncSession) -> bool:
result = await session.execute(select(ChunkRow.id).limit(1))
return result.scalar_one_or_none() is not None

@staticmethod
async def get_source_and_chunk_counts(session: AsyncSession) -> tuple[int, int]:

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

These methods currently don't have tests, even though all other repository methods have tests implemented. For consistency, implement unit tests for these in tests/infrastructure/db/test_repository.py.

Follow the pattern / conventions of the other tests in that file. Run tests before committing to verify behaviour and that they pass.

count_sources = await Repository.count_sources(session)
count_chunks = await Repository.count_chunks(session)
return count_sources, count_chunks

@staticmethod
async def count_chunks(session: AsyncSession) -> int:
result = await session.execute(select(func.count(ChunkRow.id)))
# if above doesn't work properly

Copy link
Copy Markdown
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Lines 26-27 look like temporary code that isn't needed since they're commented out? If yes, please remove them from the final PR.

# result = await session.execute(select(func.count().select_from(ChunkRow)))
return result.scalar_one()

@staticmethod
async def count_sources(session: AsyncSession) -> int:
result = await session.execute(select(func.count(SourceRow.id)))
return result.scalar_one()

@staticmethod
async def get_or_create_source(
session: AsyncSession, *, name: str, url: str, source_type: str
Expand Down
28 changes: 28 additions & 0 deletions tests/infrastructure/db/test_repository.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,3 +119,31 @@ async def test_upsert_idempotent_and_updates_on_conflict(session: AsyncSession):
)
row = row_result.scalar_one()
assert row.content == "version 2", "Upsert should update content on conflict"


async def test_get_source_and_chunk_counts(session: AsyncSession):
initial_sources, initial_chunks = await Repository.get_source_and_chunk_counts(session)

source = await Repository.get_or_create_source(
session, name="Test Source", url="https://example.com/test", source_type="html"
)

current_sources, current_chunks = await Repository.get_source_and_chunk_counts(session)
assert current_sources == initial_sources + 1
assert current_chunks == initial_chunks
embedding = [0.1] * settings.embedding_dim

await Repository.upsert_chunk(
session,
content="Test content about Carleton CS",
embedding=embedding,
source_url="https://example.com/test",
source_type="html",
section_heading="Overview",
content_hash="hash_query_test_001",
source_id=source.id,
)

final_sources, final_chunks = await Repository.get_source_and_chunk_counts(session)
assert final_sources == initial_sources + 1
assert final_chunks == initial_chunks + 1
Loading