Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion hamilton/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -437,6 +437,13 @@ def _get_edge_style(from_type: str, to_type: str) -> dict:

return edge_style

def _is_collect_dependency_edge(dependency_node: node.Node, target_node: node.Node) -> bool:
"""Returns true for the edge that feeds a Collect[...] dependency."""
return (
target_node.node_role == node.NodeType.COLLECT
and dependency_node.name == target_node.collect_dependency
)

def _get_legend(
node_types: set[str], extra_legend_nodes: dict[tuple[str, str], dict[str, str]]
):
Expand Down Expand Up @@ -648,7 +655,6 @@ def _insert_space_after_colon(col: str) -> str:
# create edges
input_sets = dict()
for n in nodes:
to_type = "collect" if n.node_role == node.NodeType.COLLECT else ""
to_modifiers = node_modifiers.get(n.name, set())

input_nodes = set()
Expand All @@ -664,6 +670,7 @@ def _insert_space_after_colon(col: str) -> str:
continue

from_type = "expand" if d.node_role == node.NodeType.EXPAND else ""
to_type = "collect" if _is_collect_dependency_edge(d, n) else ""
dependency_modifiers = node_modifiers.get(d.name, set())
edge_style = _get_edge_style(from_type, to_type)
if (
Expand Down
27 changes: 27 additions & 0 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@
import tests.resources.display_name_list_functions
import tests.resources.dummy_functions
import tests.resources.dummy_functions_module_override
import tests.resources.dynamic_parallelism.parallel_collect_multiple_arguments
import tests.resources.extract_column_nodes
import tests.resources.extract_columns_execution_count
import tests.resources.functions_with_generics
Expand Down Expand Up @@ -1198,6 +1199,32 @@ def test_function_graph_display_config_node():
assert any(line.startswith("\tX") for line in dot.body)


def test_function_graph_display_collect_edges_only_style_collected_dependency():
config = {}
fg = graph.FunctionGraph.from_modules(
tests.resources.dynamic_parallelism.parallel_collect_multiple_arguments, config=config
)

assert fg.nodes["summed"].node_role == NodeType.COLLECT
assert fg.nodes["summed"].collect_dependency == "double"

dot = fg.display(set(fg.get_nodes()), output_file_path=None, config=config)

edge_lines = {
line.strip().split(" [", 1)[0]: line.strip() for line in dot.body if " -> " in line
}
assert edge_lines["double -> summed"] == "double -> summed [arrowtail=crow dir=both]"
assert edge_lines["not_to_repeat -> summed"] == "not_to_repeat -> summed"
assert (
edge_lines["something_else_not_to_repeat -> summed"]
== "something_else_not_to_repeat -> summed"
)
assert (
edge_lines["number_to_repeat -> double"]
== "number_to_repeat -> double [arrowhead=crow arrowtail=none dir=both]"
)


# TODO use high-level visualization dot as fixtures for reuse across tests
def test_display_config_node():
"""Check if config is displayed by high-level hamilton.driver.display..."""
Expand Down
Loading