From 7e4134dfa8c68545e7f8185c6f439827b39d472b Mon Sep 17 00:00:00 2001 From: Luke Melton Date: Sat, 27 Jun 2026 14:09:26 -0400 Subject: [PATCH] Fix collect visualization edge styling --- hamilton/graph.py | 9 ++++++++- tests/test_graph.py | 27 +++++++++++++++++++++++++++ 2 files changed, 35 insertions(+), 1 deletion(-) diff --git a/hamilton/graph.py b/hamilton/graph.py index 6e6caa864..a22b51851 100644 --- a/hamilton/graph.py +++ b/hamilton/graph.py @@ -437,6 +437,13 @@ def _get_edge_style(from_type: str, to_type: str) -> dict: return edge_style + def _is_collect_dependency_edge(dependency_node: node.Node, target_node: node.Node) -> bool: + """Returns true for the edge that feeds a Collect[...] dependency.""" + return ( + target_node.node_role == node.NodeType.COLLECT + and dependency_node.name == target_node.collect_dependency + ) + def _get_legend( node_types: set[str], extra_legend_nodes: dict[tuple[str, str], dict[str, str]] ): @@ -648,7 +655,6 @@ def _insert_space_after_colon(col: str) -> str: # create edges input_sets = dict() for n in nodes: - to_type = "collect" if n.node_role == node.NodeType.COLLECT else "" to_modifiers = node_modifiers.get(n.name, set()) input_nodes = set() @@ -664,6 +670,7 @@ def _insert_space_after_colon(col: str) -> str: continue from_type = "expand" if d.node_role == node.NodeType.EXPAND else "" + to_type = "collect" if _is_collect_dependency_edge(d, n) else "" dependency_modifiers = node_modifiers.get(d.name, set()) edge_style = _get_edge_style(from_type, to_type) if ( diff --git a/tests/test_graph.py b/tests/test_graph.py index bbf473d79..b20000184 100644 --- a/tests/test_graph.py +++ b/tests/test_graph.py @@ -42,6 +42,7 @@ import tests.resources.display_name_list_functions import tests.resources.dummy_functions import tests.resources.dummy_functions_module_override +import tests.resources.dynamic_parallelism.parallel_collect_multiple_arguments import tests.resources.extract_column_nodes import tests.resources.extract_columns_execution_count import tests.resources.functions_with_generics @@ -1198,6 +1199,32 @@ def test_function_graph_display_config_node(): assert any(line.startswith("\tX") for line in dot.body) +def test_function_graph_display_collect_edges_only_style_collected_dependency(): + config = {} + fg = graph.FunctionGraph.from_modules( + tests.resources.dynamic_parallelism.parallel_collect_multiple_arguments, config=config + ) + + assert fg.nodes["summed"].node_role == NodeType.COLLECT + assert fg.nodes["summed"].collect_dependency == "double" + + dot = fg.display(set(fg.get_nodes()), output_file_path=None, config=config) + + edge_lines = { + line.strip().split(" [", 1)[0]: line.strip() for line in dot.body if " -> " in line + } + assert edge_lines["double -> summed"] == "double -> summed [arrowtail=crow dir=both]" + assert edge_lines["not_to_repeat -> summed"] == "not_to_repeat -> summed" + assert ( + edge_lines["something_else_not_to_repeat -> summed"] + == "something_else_not_to_repeat -> summed" + ) + assert ( + edge_lines["number_to_repeat -> double"] + == "number_to_repeat -> double [arrowhead=crow arrowtail=none dir=both]" + ) + + # TODO use high-level visualization dot as fixtures for reuse across tests def test_display_config_node(): """Check if config is displayed by high-level hamilton.driver.display..."""