diff --git a/hamilton/driver.py b/hamilton/driver.py index 9ca646ae9..dc063c475 100644 --- a/hamilton/driver.py +++ b/hamilton/driver.py @@ -793,6 +793,19 @@ def list_available_variables( results = [Variable.from_node(n) for n in all_nodes] return results + def get_variable(self, name: str) -> Variable: + """Returns a variable by name. + + :param name: Name of the variable to return. + :return: Matching Variable. + :raises KeyError: If the variable does not exist in this Driver's graph. + """ + return Variable.from_node(self.graph.nodes[name]) + + def get_graph(self) -> graph_types.HamiltonGraph: + """Returns the public HamiltonGraph representation for this Driver.""" + return graph_types.HamiltonGraph.from_graph(self.graph) + @capture_function_usage def display_all_functions( self, diff --git a/tests/test_hamilton_driver.py b/tests/test_hamilton_driver.py index ffea258fc..f2e521db4 100644 --- a/tests/test_hamilton_driver.py +++ b/tests/test_hamilton_driver.py @@ -216,6 +216,24 @@ def test_driver_variables_exposes_original_function(): assert originating_functions["a"] == (tests.resources.very_simple_dag.b,) # a is an input +def test_driver_variable_lookup(): + dr = Driver({}, tests.resources.very_simple_dag) + + assert dr.get_variable("b").name == "b" + assert dr.get_variable("a").is_external_input is True + + with pytest.raises(KeyError): + dr.get_variable("missing") + + +def test_driver_get_graph_returns_hamilton_graph(): + dr = Driver({}, tests.resources.very_simple_dag) + + hamilton_graph = dr.get_graph() + + assert hamilton_graph["b"].name == "b" + + @pytest.mark.parametrize( "driver_factory", [