Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
64 changes: 60 additions & 4 deletions src/landingai_ade/_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,6 +324,7 @@ def classify(
document: Optional[FileTypes] | Omit = omit,
document_url: Optional[str] | Omit = omit,
model: Optional[str] | Omit = omit,
save_to: str | Path | None = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -356,6 +357,11 @@ def classify(

model: Classification model version. Defaults to the latest.

save_to: Optional output path. If a directory, auto-generates the filename
(e.g. {input_file}_classify_output.json, or classify_output.json when no
input filename is available). If a full path ending in .json, saves there
directly. Parent directories are created automatically.

extra_headers: Send extra headers

extra_query: Add additional query parameters to the request
Expand All @@ -364,6 +370,10 @@ def classify(

timeout: Override the client-level default timeout for this request, in seconds
"""
# Store original inputs for filename extraction
original_document = document
original_document_url = document_url

body = deepcopy_with_paths(
{
"classes": classes,
Expand All @@ -378,7 +388,7 @@ def classify(
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self.post(
result = self.post(
"/v1/ade/classify",
body=maybe_transform(body, client_classify_params.ClientClassifyParams),
files=files,
Expand All @@ -387,6 +397,10 @@ def classify(
),
cast_to=ClassifyResponse,
)
if save_to:
Comment thread
mikesoennichsen marked this conversation as resolved.
filename = _get_input_filename(original_document, original_document_url)
_save_response(save_to, filename, "classify", result)
return result

def extract(
self,
Expand Down Expand Up @@ -667,6 +681,7 @@ def section(
markdown: Union[FileTypes, str, None] | Omit = omit,
markdown_url: Optional[str] | Omit = omit,
model: Optional[str] | Omit = omit,
save_to: str | Path | None = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -696,6 +711,11 @@ def section(

model: Section model version. Defaults to latest.

save_to: Optional output path. If a directory, auto-generates the filename
(e.g. {input_file}_section_output.json, or section_output.json when no
input filename is available). If a full path ending in .json, saves there
directly. Parent directories are created automatically.

extra_headers: Send extra headers

extra_query: Add additional query parameters to the request
Expand All @@ -704,6 +724,10 @@ def section(

timeout: Override the client-level default timeout for this request, in seconds
"""
# Store original inputs for filename extraction
original_markdown = markdown
original_markdown_url = markdown_url

body = deepcopy_with_paths(
{
"guidelines": guidelines,
Expand All @@ -718,7 +742,7 @@ def section(
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return self.post(
result = self.post(
"/v1/ade/section",
body=maybe_transform(body, client_section_params.ClientSectionParams),
files=files,
Expand All @@ -727,6 +751,10 @@ def section(
),
cast_to=SectionResponse,
)
if save_to:
Comment thread
mikesoennichsen marked this conversation as resolved.
filename = _get_input_filename(original_markdown, original_markdown_url)
_save_response(save_to, filename, "section", result)
return result

def split(
self,
Expand Down Expand Up @@ -1014,6 +1042,7 @@ async def classify(
document: Optional[FileTypes] | Omit = omit,
document_url: Optional[str] | Omit = omit,
model: Optional[str] | Omit = omit,
save_to: str | Path | None = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -1046,6 +1075,11 @@ async def classify(

model: Classification model version. Defaults to the latest.

save_to: Optional output path. If a directory, auto-generates the filename
(e.g. {input_file}_classify_output.json, or classify_output.json when no
input filename is available). If a full path ending in .json, saves there
directly. Parent directories are created automatically.

extra_headers: Send extra headers

extra_query: Add additional query parameters to the request
Expand All @@ -1054,6 +1088,10 @@ async def classify(

timeout: Override the client-level default timeout for this request, in seconds
"""
# Store original inputs for filename extraction
original_document = document
original_document_url = document_url

body = deepcopy_with_paths(
{
"classes": classes,
Expand All @@ -1068,7 +1106,7 @@ async def classify(
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self.post(
result = await self.post(
"/v1/ade/classify",
body=await async_maybe_transform(body, client_classify_params.ClientClassifyParams),
files=files,
Expand All @@ -1077,6 +1115,10 @@ async def classify(
),
cast_to=ClassifyResponse,
)
if save_to:
Comment thread
mikesoennichsen marked this conversation as resolved.
filename = _get_input_filename(original_document, original_document_url)
_save_response(save_to, filename, "classify", result)
return result

async def extract(
self,
Expand Down Expand Up @@ -1358,6 +1400,7 @@ async def section(
markdown: Union[FileTypes, str, None] | Omit = omit,
markdown_url: Optional[str] | Omit = omit,
model: Optional[str] | Omit = omit,
save_to: str | Path | None = None,
# Use the following arguments if you need to pass additional parameters to the API that aren't available via kwargs.
# The extra values given here take precedence over values defined on the client or passed to this method.
extra_headers: Headers | None = None,
Expand Down Expand Up @@ -1387,6 +1430,11 @@ async def section(

model: Section model version. Defaults to latest.

save_to: Optional output path. If a directory, auto-generates the filename
(e.g. {input_file}_section_output.json, or section_output.json when no
input filename is available). If a full path ending in .json, saves there
directly. Parent directories are created automatically.

extra_headers: Send extra headers

extra_query: Add additional query parameters to the request
Expand All @@ -1395,6 +1443,10 @@ async def section(

timeout: Override the client-level default timeout for this request, in seconds
"""
# Store original inputs for filename extraction
original_markdown = markdown
original_markdown_url = markdown_url

body = deepcopy_with_paths(
{
"guidelines": guidelines,
Expand All @@ -1409,7 +1461,7 @@ async def section(
# sent to the server will contain a `boundary` parameter, e.g.
# multipart/form-data; boundary=---abc--
extra_headers = {"Content-Type": "multipart/form-data", **(extra_headers or {})}
return await self.post(
result = await self.post(
"/v1/ade/section",
body=await async_maybe_transform(body, client_section_params.ClientSectionParams),
files=files,
Expand All @@ -1418,6 +1470,10 @@ async def section(
),
cast_to=SectionResponse,
)
if save_to:
Comment thread
mikesoennichsen marked this conversation as resolved.
filename = _get_input_filename(original_markdown, original_markdown_url)
_save_response(save_to, filename, "section", result)
return result

async def split(
self,
Expand Down
98 changes: 95 additions & 3 deletions tests/test_save_to.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@

import pytest

from landingai_ade import AsyncLandingAIADE
from landingai_ade import LandingAIADE, AsyncLandingAIADE
from landingai_ade._client import _save_response, _get_input_filename
from landingai_ade._exceptions import LandingAiadeError

Expand Down Expand Up @@ -128,7 +128,7 @@ def test_correct_filename_format(self, tmp_path: Path) -> None:
mock_result = MagicMock()
mock_result.to_json.return_value = "{}"

for method in ["parse", "extract", "split"]:
for method in ["parse", "extract", "split", "classify", "section"]:
_save_response(tmp_path, "myinput", method, mock_result)
expected_file = tmp_path / f"myinput_{method}_output.json"
assert expected_file.exists(), f"Expected {expected_file} to exist"
Expand Down Expand Up @@ -165,7 +165,7 @@ def test_output_filename_skips_redundant_prefix(self, tmp_path: Path) -> None:
mock_result = MagicMock()
mock_result.to_json.return_value = "{}"

for method in ["parse", "extract", "split"]:
for method in ["parse", "extract", "split", "classify", "section"]:
_save_response(tmp_path, "output", method, mock_result)
expected = tmp_path / f"{method}_output.json"
assert expected.exists(), f"Expected {expected} to exist"
Expand Down Expand Up @@ -266,6 +266,62 @@ async def test_async_split_save_to(self, tmp_path: Path, mock_response: MagicMoc

assert (tmp_path / "doc_split_output.json").exists()

@pytest.mark.asyncio
async def test_async_classify_save_to_directory(self, tmp_path: Path, mock_response: MagicMock) -> None:
from unittest.mock import AsyncMock, patch

async with AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") as client:
with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response):
await client.classify(
classes=[{"class": "invoice"}],
document=Path("/path/to/doc.pdf"),
save_to=tmp_path,
)

assert (tmp_path / "doc_classify_output.json").exists()

@pytest.mark.asyncio
async def test_async_classify_save_to_json_path(self, tmp_path: Path, mock_response: MagicMock) -> None:
from unittest.mock import AsyncMock, patch

output_file = tmp_path / "custom.json"
async with AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") as client:
with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response):
await client.classify(
classes=[{"class": "invoice"}],
document_url="https://example.com/doc.pdf",
save_to=output_file,
)

assert output_file.exists()

@pytest.mark.asyncio
async def test_async_section_save_to_directory(self, tmp_path: Path, mock_response: MagicMock) -> None:
from unittest.mock import AsyncMock, patch

async with AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") as client:
with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response):
await client.section(
markdown=Path("/path/to/doc.md"),
save_to=tmp_path,
)

assert (tmp_path / "doc_section_output.json").exists()

@pytest.mark.asyncio
async def test_async_section_save_to_with_string_input(self, tmp_path: Path, mock_response: MagicMock) -> None:
"""Section commonly receives raw markdown strings — filename should fall back to 'output'."""
from unittest.mock import AsyncMock, patch

async with AsyncLandingAIADE(apikey="test-key", base_url="http://localhost") as client:
with patch.object(client, "post", new_callable=AsyncMock, return_value=mock_response):
await client.section(
markdown="# Heading\n\nbody content",
save_to=tmp_path,
)

assert (tmp_path / "section_output.json").exists()

@pytest.mark.asyncio
async def test_async_no_save_when_save_to_none(self, tmp_path: Path, mock_response: MagicMock) -> None:
from unittest.mock import AsyncMock, patch
Expand All @@ -276,3 +332,39 @@ async def test_async_no_save_when_save_to_none(self, tmp_path: Path, mock_respon

assert result is mock_response
assert not list(tmp_path.iterdir())


class TestSyncSaveTo:
"""Tests that sync client methods accept save_to and save correctly for classify/section."""

@pytest.fixture
def mock_response(self) -> MagicMock:
mock = MagicMock()
mock.to_json.return_value = '{"result": "ok"}'
return mock

def test_sync_classify_save_to_directory(self, tmp_path: Path, mock_response: MagicMock) -> None:
from unittest.mock import patch

client = LandingAIADE(apikey="test-key", base_url="http://localhost")
with patch.object(client, "post", return_value=mock_response):
client.classify(
classes=[{"class": "invoice"}],
document=Path("/path/to/doc.pdf"),
save_to=tmp_path,
)
Comment thread
mikesoennichsen marked this conversation as resolved.

assert (tmp_path / "doc_classify_output.json").exists()

def test_sync_section_save_to_json_path(self, tmp_path: Path, mock_response: MagicMock) -> None:
from unittest.mock import patch

output_file = tmp_path / "toc.json"
client = LandingAIADE(apikey="test-key", base_url="http://localhost")
with patch.object(client, "post", return_value=mock_response):
client.section(
markdown=Path("/path/to/doc.md"),
save_to=output_file,
)
Comment thread
mikesoennichsen marked this conversation as resolved.

assert output_file.exists()