From ea40bc019a7df060fac438fc83d68ee4b827ef72 Mon Sep 17 00:00:00 2001 From: Asish Kumar Date: Mon, 25 May 2026 03:15:06 +0530 Subject: [PATCH 1/2] Add driver variable lookup helpers --- hamilton/driver.py | 21 +++++++++++++++++++++ tests/test_hamilton_driver.py | 16 ++++++++++++++++ 2 files changed, 37 insertions(+) diff --git a/hamilton/driver.py b/hamilton/driver.py index 541b1da02..76c0a69ce 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -16,6 +16,7 @@ # under the License. import abc +import functools import importlib import importlib.util import json @@ -792,6 +793,26 @@ def list_available_variables( results = [Variable.from_node(n) for n in all_nodes] return results + @functools.cached_property + def variables(self) -> dict[str, Variable]: + """Returns all variables in the graph keyed by name.""" + return { + node_name: Variable.from_node(node_) for node_name, node_ in self.graph.nodes.items() + } + + def get_variable(self, name: str) -> Variable: + """Returns a variable by name. + + :param name: Name of the variable to return. + :return: Matching HamiltonNode. + :raises KeyError: If the variable does not exist in this Driver's graph. + """ + return self.variables[name] + + def get_graph(self) -> graph_types.HamiltonGraph: + """Returns the public HamiltonGraph representation for this Driver.""" + return graph_types.HamiltonGraph.from_graph(self.graph) + @capture_function_usage def display_all_functions( self, diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py index ffea258fc..e87a46275 100644 --- a/tests/test_hamilton_driver.py +++ b/tests/test_hamilton_driver.py @@ -216,6 +216,22 @@ def test_driver_variables_exposes_original_function(): assert originating_functions["a"] == (tests.resources.very_simple_dag.b,) # a is an input +def test_driver_variable_lookup(): + dr = Driver({}, tests.resources.very_simple_dag) + + assert set(dr.variables) == {"a", "b"} + assert dr.variables["b"].name == "b" + assert dr.get_variable("a").is_external_input is True + + +def test_driver_get_graph_returns_hamilton_graph(): + dr = Driver({}, tests.resources.very_simple_dag) + + hamilton_graph = dr.get_graph() + + assert hamilton_graph["b"].name == "b" + + @pytest.mark.parametrize( "driver_factory", [ From 4206838c608679a17790ebb054841417db001e96 Mon Sep 17 00:00:00 2001 From: Asish Kumar Date: Wed, 1 Jul 2026 07:47:07 +0000 Subject: [PATCH 2/2] Remove Driver variables cache --- hamilton/driver.py | 12 ++---------- tests/test_hamilton_driver.py | 6 ++++-- 2 files changed, 6 insertions(+), 12 deletions(-) diff --git a/hamilton/driver.py b/hamilton/driver.py index e6828939e..dc063c475 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -16,7 +16,6 @@ # under the License. import abc -import functools import importlib import importlib.util import json @@ -794,21 +793,14 @@ def list_available_variables( results = [Variable.from_node(n) for n in all_nodes] return results - @functools.cached_property - def variables(self) -> dict[str, Variable]: - """Returns all variables in the graph keyed by name.""" - return { - node_name: Variable.from_node(node_) for node_name, node_ in self.graph.nodes.items() - } - def get_variable(self, name: str) -> Variable: """Returns a variable by name. :param name: Name of the variable to return. - :return: Matching HamiltonNode. + :return: Matching Variable. :raises KeyError: If the variable does not exist in this Driver's graph. """ - return self.variables[name] + return Variable.from_node(self.graph.nodes[name]) def get_graph(self) -> graph_types.HamiltonGraph: """Returns the public HamiltonGraph representation for this Driver.""" diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py index e87a46275..f2e521db4 100644 --- a/tests/test_hamilton_driver.py +++ b/tests/test_hamilton_driver.py @@ -219,10 +219,12 @@ def test_driver_variables_exposes_original_function(): def test_driver_variable_lookup(): dr = Driver({}, tests.resources.very_simple_dag) - assert set(dr.variables) == {"a", "b"} - assert dr.variables["b"].name == "b" + assert dr.get_variable("b").name == "b" assert dr.get_variable("a").is_external_input is True + with pytest.raises(KeyError): + dr.get_variable("missing") + def test_driver_get_graph_returns_hamilton_graph(): dr = Driver({}, tests.resources.very_simple_dag)