diff --git a/CHANGES.rst b/CHANGES.rst index d2afd0ad..175545b5 100644 --- a/CHANGES.rst +++ b/CHANGES.rst @@ -2,6 +2,12 @@ Changes for crate ================= +Unreleased +================ + +- Breaking change: ``connect()`` now raises ``ConnectionError`` immediately if +no configured server node responds. + 2026/06/17 2.2.1 ================ diff --git a/docs/by-example/client.rst b/docs/by-example/client.rst index 995ee745..52897375 100644 --- a/docs/by-example/client.rst +++ b/docs/by-example/client.rst @@ -29,12 +29,8 @@ respond, the request is automatically routed to the next server: >>> connection = client.connect([invalid_host, crate_host]) >>> connection.close() -If no ``servers`` are given, the default one ``http://127.0.0.1:4200`` is used: - - >>> connection = client.connect() - >>> connection.client._active_servers - ['http://127.0.0.1:4200'] - >>> connection.close() +If no ``servers`` are supplied to the ``connect`` method, the default address +``http://127.0.0.1:4200`` is used. If the option ``error_trace`` is set to ``True``, the client will print a whole traceback if a server error occurs: @@ -77,7 +73,7 @@ connect: The username for trusted users can also be provided in the URL: - >>> connection = client.connect(['http://trusted_me@' + crate_host]) + >>> connection = client.connect([crate_host.replace('://', '://trusted_me@')]) >>> connection.client.username 'trusted_me' >>> connection.client.password @@ -97,7 +93,7 @@ also need to provide ``password`` as argument for the ``connect()`` call: The authentication credentials can also be provided in the URL: - >>> connection = client.connect(['http://me:my_secret_pw@' + crate_host]) + >>> connection = client.connect([crate_host.replace('://', '://me:my_secret_pw@')]) >>> connection.client.username 'me' >>> connection.client.password diff --git a/docs/connect.rst b/docs/connect.rst index afc8f59c..7fa7c2eb 100644 --- a/docs/connect.rst +++ b/docs/connect.rst @@ -73,13 +73,32 @@ You can pass in as many node URLs as you like. .. TIP:: - For every query, the client will attempt to connect to each node in sequence + When ``connect()`` is called, the client contacts each node to check + availability and determine the lowest server version. If no node responds, + a ``ConnectionError`` is raised immediately. + + For every subsequent query, the client will attempt each node in sequence until a successful connection is made. Nodes are moved to the end of the list each time they are tried. Over multiple query executions, this behaviour functions as client-side *round-robin* load balancing. (This is analogous to `round-robin DNS`_.) +.. NOTE:: + + Wrap ``connect()`` in a ``try/except`` block to handle an unreachable + cluster gracefully: + + .. code-block:: python + + from crate import client + from crate.client.exceptions import ConnectionError + + try: + connection = client.connect(["node-1:4200", "node-2:4200"]) + except ConnectionError as e: + print(f"Could not reach CrateDB cluster: {e}") + .. _connection-options: Connection options diff --git a/src/crate/client/connection.py b/src/crate/client/connection.py index f722b848..e499ce14 100644 --- a/src/crate/client/connection.py +++ b/src/crate/client/connection.py @@ -18,7 +18,6 @@ # However, if you have executed another commercial license agreement # with Crate these terms will supersede the license and you may use the # software solely pursuant to the terms of the relevant commercial agreement. - from typing import Union from verlib2 import Version @@ -211,14 +210,21 @@ def get_blob_container(self, container_name): def _lowest_server_version(self): lowest = None - for server in self.client.active_servers: + servers = self.client.active_servers + connection_errors = [] + for server in servers: try: _, _, version = self.client.server_infos(server) version = Version(version) - except (ValueError, ConnectionError): + except ConnectionError as ex: + connection_errors.append(ex) + continue + except ValueError: continue if not lowest or version < lowest: lowest = version + if connection_errors and not lowest: + raise ConnectionError("; ".join(str(e) for e in connection_errors)) return lowest or Version("0.0.0") def __repr__(self): diff --git a/tests/client/test_connection.py b/tests/client/test_connection.py index 90b121f2..464ec123 100644 --- a/tests/client/test_connection.py +++ b/tests/client/test_connection.py @@ -4,6 +4,7 @@ import pytest from urllib3 import Timeout +import crate.client.exceptions from crate.client import connect from crate.client.connection import Connection from crate.client.exceptions import ProgrammingError @@ -12,6 +13,40 @@ from .settings import crate_host +class _FakeClient: + """ + Minimal stand-in for Client that lets tests control server_infos. + """ + + def __init__(self, servers, server_infos_fn): + self._servers = list(servers) + self._server_infos_fn = server_infos_fn + + @property + def active_servers(self): + return list(self._servers) + + def server_infos(self, server): + return self._server_infos_fn(server) + + +def _bare_conn(client): + """ + Create a Connection that bypasses __init__. + """ + + conn = Connection.__new__(Connection) + conn.client = client + return conn + + +def test_invalid_server_address(): + client = Client(servers="localhost:1234") + with pytest.raises(crate.client.exceptions.ConnectionError) as excinfo: + connect(client=client) + assert excinfo.match("Server not available") + + def test_lowest_server_version(): """ Verify the lowest server version is correctly set. @@ -55,10 +90,13 @@ def test_connection_closes_access(): def test_connection_closes_context_manager(): """Verify that the context manager of the client closes the connection""" - with patch.object(connect, "close", autospec=True) as close_fn: - with connect(): - pass - close_fn.assert_called_once() + with patch.object( + Client, "server_infos", return_value=(None, None, "0.0.0") + ): + with patch.object(connect, "close", autospec=True) as close_fn: + with connect(): + pass + close_fn.assert_called_once() def test_invalid_server_version(): @@ -78,8 +116,11 @@ def test_context_manager(): """ close_method = "crate.client.http.Client.close" with patch(close_method, return_value=MagicMock()) as close_func: - with connect("localhost:4200") as conn: - assert not conn._closed + with patch.object( + Client, "server_infos", return_value=(None, None, "0.0.0") + ): + with connect("localhost:4200") as conn: + assert not conn._closed assert conn._closed # Checks that the close method of the client @@ -115,7 +156,10 @@ def test_default_repr(): """ Verify default repr dunder method. """ - conn = connect() + with patch.object( + Client, "server_infos", return_value=(None, None, "0.0.0") + ): + conn = connect() assert repr(conn) == ">" @@ -132,7 +176,10 @@ def test_with_timezone(): """ tz_mst = datetime.timezone(datetime.timedelta(hours=7), name="MST") - connection = connect("localhost:4200", time_zone=tz_mst) + with patch.object( + Client, "server_infos", return_value=(None, None, "0.0.0") + ): + connection = connect("localhost:4200", time_zone=tz_mst) cursor = connection.cursor() assert cursor.time_zone.tzname(None) == "MST" @@ -148,16 +195,22 @@ def test_timeout_float(): """ Verify setting the timeout value as a scalar (float) works. """ - with connect("localhost:4200", timeout=2.42) as conn: - assert conn.client._pool_kw["timeout"] == 2.42 + with patch.object( + Client, "server_infos", return_value=(None, None, "0.0.0") + ): + with connect("localhost:4200", timeout=2.42) as conn: + assert conn.client._pool_kw["timeout"] == 2.42 def test_timeout_string(): """ Verify setting the timeout value as a scalar (string) works. """ - with connect("localhost:4200", timeout="2.42") as conn: - assert conn.client._pool_kw["timeout"] == 2.42 + with patch.object( + Client, "server_infos", return_value=(None, None, "0.0.0") + ): + with connect("localhost:4200", timeout="2.42") as conn: + assert conn.client._pool_kw["timeout"] == 2.42 def test_timeout_object(): @@ -165,5 +218,102 @@ def test_timeout_object(): Verify setting the timeout value as a Timeout object works. """ timeout = Timeout(connect=2.42, read=0.01) - with connect("localhost:4200", timeout=timeout) as conn: - assert conn.client._pool_kw["timeout"] == timeout + with patch.object( + Client, "server_infos", return_value=(None, None, "0.0.0") + ): + with connect("localhost:4200", timeout=timeout) as conn: + assert conn.client._pool_kw["timeout"] == timeout + + +def test_partial_failure_raises(): + """ + When some servers fail with ConnectionError and others produce an + unparseable version string (triggering ValueError/InvalidVersion), + the method must still raise rather than silently returning Version("0.0.0"). + + Risk: len(connection_errors) < server_count because only ConnectionError + instances are counted, so the all-failed guard never fires. + """ + + def server_infos(server): + if "4200" in server: + raise crate.client.exceptions.ConnectionError( + "Server not available" + ) + # "bad-version" triggers InvalidVersion inside Version(), which is + # caught by the second except branch and never appended to + # connection_errors. + return (None, None, "bad-version") + + client = _FakeClient( + ["http://localhost:4200", "http://localhost:4201"], + server_infos, + ) + conn = _bare_conn(client) + + with pytest.raises(crate.client.exceptions.ConnectionError): + conn._lowest_server_version() + + +def test_error_message_contains_individual_errors(): + """ + When all servers fail with ConnectionError the raised exception message + must contain each individual server's error text so operators can see + which nodes are down. + """ + msgs = { + "http://localhost:4200": "node-A refused connection", + "http://localhost:4201": "node-B timed out", + } + + def server_infos(server): + raise crate.client.exceptions.ConnectionError(msgs[server]) + + client = _FakeClient(list(msgs), server_infos) + conn = _bare_conn(client) + + with pytest.raises(crate.client.exceptions.ConnectionError) as excinfo: + conn._lowest_server_version() + + msg = str(excinfo.value) + assert "node-A refused connection" in msg + assert "node-B timed out" in msg + + +def test_active_servers_double_evaluation(): + """ + active_servers is evaluated twice: once for len() (server_count) and once + for the for-loop. If more servers appear between the two calls, every + iterated server can fail with ConnectionError yet len(connection_errors) + exceeds the stale server_count, causing the all-failed guard to miss. + + """ + + class _UnstableClient: + def __init__(self): + self._calls = 0 + + @property + def active_servers(self): + self._calls += 1 + if self._calls == 1: + # First access: len() call — reports 2 servers. + return ["http://localhost:4200", "http://localhost:4201"] + # Second access: for-loop — a third server appeared concurrently. + return [ + "http://localhost:4200", + "http://localhost:4201", + "http://localhost:4202", + ] + + def server_infos(self, server): + raise crate.client.exceptions.ConnectionError( + "Server not available" + ) + + conn = _bare_conn(_UnstableClient()) + + # All 3 iterated servers fail, but server_count=2 (stale). + # 3 != 2 → guard never fires → silently returns Version("0.0.0"). + with pytest.raises(crate.client.exceptions.ConnectionError): + conn._lowest_server_version() diff --git a/tests/client/test_http.py b/tests/client/test_http.py index 573fca68..3510de41 100644 --- a/tests/client/test_http.py +++ b/tests/client/test_http.py @@ -655,7 +655,14 @@ def do_POST(self): time.sleep(timeout + 0.1) def do_GET(self): - pass + body = json.dumps( + {"name": "test", "version": {"number": "0.0.0"}} + ).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) # Start the http server. with serve_http(TimeoutRequestHandler) as (server, url): @@ -710,7 +717,14 @@ def do_POST(self): self.wfile.write(response.encode("utf-8")) def do_GET(self): - pass + body = json.dumps( + {"name": "test", "version": {"number": "0.0.0"}} + ).encode() + self.send_response(200) + self.send_header("Content-Type", "application/json") + self.send_header("Content-Length", str(len(body))) + self.end_headers() + self.wfile.write(body) def test_default_schema(serve_http):