diff --git a/hamilton/function_modifiers/macros.py b/hamilton/function_modifiers/macros.py index 3258e83f2..7aa9cc6c4 100644 --- a/hamilton/function_modifiers/macros.py +++ b/hamilton/function_modifiers/macros.py @@ -493,26 +493,27 @@ def named(self, name: str, namespace: NamespaceType = ...) -> "Applicable": # on_input / on_output are the same but here for naming convention # I know there is a way to dynamically resolve this to revert to a common function # just can't remember it now or find it online... - # TODO: adding the option to select target parameter for each transform - # def on_input(self, target: base.TargetType) -> "Applicable": - # """Add Target on a single function level. - - # This determines to which node(s) it will applies. Should match the same naming convention - # as the NodeTransorfmLifecycle child class (for example NodeTransformer). - - # :param target: Which node(s) to apply on top of - # :return: The Applicable with specified target - # """ - # return Applicable( - # fn=self.fn, - # _resolvers=self.resolvers, - # _name=self.name, - # _namespace=self.namespace, - # _target=target if target is not None else self.target, - # args=self.args, - # kwargs=self.kwargs, - # target_fn=self.target_fn, - # ) + def on_input(self, target: base.TargetType) -> "Applicable": + """Add Target on a single function level. + + This determines to which node(s) it will applies. Should match the same naming convention + as the NodeTransorfmLifecycle child class (for example NodeTransformer). + + :param target: Which node(s) to apply on top of + :return: The Applicable with specified target + """ + base.NodeTransformer._early_validate_target(target=target, allow_multiple=True) + pipe_input._validate_on_input_target(target=target) + return Applicable( + fn=self.fn, + _resolvers=self.resolvers, + _name=self.name, + _namespace=self.namespace, + _target=target if target is not None else self.target, + args=self.args, + kwargs=self.kwargs, + target_fn=self.target_fn, + ) def on_output(self, target: base.TargetType) -> "Applicable": """Add Target on a single function level. @@ -830,8 +831,7 @@ def final_result(upstream_int: int) -> int: consumption/output later. Setting the namespace in individual nodes as well as in ``pipe_input`` is not yet supported. 3. ``on_input`` -- this selects which input we will run the pipeline on. - In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter. Let us know if you wish to expand to other use-cases. - You can track the progress on this topic via: https://github.com/apache/hamilton/issues/1177 + In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter. The following would apply function *_add_one* and *_add_two* to ``p2``: @@ -845,39 +845,22 @@ def final_result(upstream_int: int) -> int: def final_result(p1: int, p2: int, p3: int) -> int: return upstream_int - .. | - THIS IS COMMENTED OUT, I.E. SPHINX WILL NOT AUTODOC IT, HERE IN CASE WE ENABLE MULTIPLE PARAMETER TARGETS - For extra control in case of multiple function arguments (parameters), we can also specify the target parameter that we wish to transform. - In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter only. If ``on_input`` is set for a specific transform - make sure the other ones are also set either through a global setting or individually, otherwise it is unclear which transforms target which parameters. + For extra control in case of multiple function arguments (parameters), we can specify target parameters globally + with ``on_input=["p1", "p3"]`` or locally with ``step(...).on_input("p2")``. If a local ``on_input`` is used + without a global ``on_input``, every step needs a local target so the target distribution is unambiguous. - The following applies *_add_one* to ``p1``, ``p3`` and *_add_two* to ``p2`` - - .. code-block:: python - - @pipe_input( - step(_add_one).on_input(["p1","p3"]) - step(_add_two, y=source("upstream_node")).on_input("p2") - ) - def final_result(p1: int, p2: int, p3: int) -> int: - return p1 + p2 + p3 + Lastly, a mixture of global and local is possible. The global target applies to all transforms, and local targets + add more parameters for individual transforms. - We can also do this on the global level to set for all transforms a target parameter. - - Lastly, a mixture of global and local is possible, where the global selects the target parameters for - all transforms and we can select individual transforms to also target more parameters. - The following would apply function *_add_one* to all ``p1``, ``p2``, ``p3`` and *_add_two* also on ``p2`` - - .. code-block:: python + .. code-block:: python - @pipe_input( - step(_add_one).on_input(["p1","p3"]) - step(_add_two, y=source("upstream_node")), - on_input = "p2" - ) - def final_result(p1: int, p2: int, p3: int) -> int: - return upstream_int - | replace:: \ + @pipe_input( + step(_add_one).on_input(["p1", "p3"]), + step(_add_two, y=source("upstream_node")), + on_input="p2", + ) + def final_result(p1: int, p2: int, p3: int) -> int: + return p1 + p2 + p3 """ def __init__( @@ -896,27 +879,14 @@ def __init__( :param collapse: Whether to collapse this into a single node. This is not currently supported. :param _chain: Whether to chain the first parameter. This is the only mode that is supported. Furthermore, this is not externally exposed. ``@flow`` will make use of this. """ - if on_input is not None: - if not isinstance(on_input, str): - raise NotImplementedError( - "on_input currently only supports a single target parameter specified by a string. " - "Please reach out if you want a more flexible option in the feature." - ) base.NodeTransformer._early_validate_target(target=on_input, allow_multiple=True) + self._validate_on_input_target(target=on_input) self.transforms = transforms self.collapse = collapse self.chain = _chain self.namespace = namespace - self.target = [on_input] - - # TODO: for multiple target parameter case - # if isinstance(on_input, str): # have to do extra since strings are collections in python - # self.target = [on_input] - # elif isinstance(on_input, Collection): - # self.target = on_input - # else: - # self.target = [on_input] + self.target = self._normalize_targets(on_input) if self.collapse: raise NotImplementedError( @@ -926,6 +896,29 @@ def __init__( if self.chain: raise NotImplementedError("@flow() is not yet supported -- this is ") + @staticmethod + def _normalize_targets(target: base.TargetType) -> list[str | None | EllipsisType]: + if isinstance(target, str): # have to do extra since strings are collections in python + return [target] + elif isinstance(target, Collection): + return list(target) + else: + return [target] + + @staticmethod + def _validate_on_input_target(target: base.TargetType): + if target is ...: + raise base.InvalidDecoratorException( + "Cannot apply Ellipsis(...) to on_input. Use None, a string, or a non-empty collection of strings." + ) + if isinstance(target, Collection) and not isinstance(target, str) and len(target) == 0: + raise base.InvalidDecoratorException( + "Cannot apply an empty collection to on_input. Use None, a string, or a non-empty collection of strings." + ) + + def _explicit_global_targets(self) -> list[str]: + return [target for target in self.target if isinstance(target, str)] + def _distribute_transforms_to_parameters( self, params: dict[str, type[type]] ) -> dict[str, list[Applicable]]: @@ -941,23 +934,21 @@ def _distribute_transforms_to_parameters( """ selected_transforms = defaultdict(list) + global_targets = self._explicit_global_targets() for param in params: - if param in self.target: - selected_transforms[param].extend(self.transforms) - # TODO: in case of multiple parameters we can set individual targets and resolve them here - # for transform in self.transforms: - # target = transform.target - # # In case there is no target set on applicable we assign global target - # if target is None: - # target = self.target - # elif isinstance(target, str): # user selects single target via string - # target = [target] - # target.extend(self.target) - # elif isinstance(target, Collection): # user inputs a list of targets - # target.extend(self.target) - - # if param in target: - # selected_transforms[param].append(transform) + for transform in self.transforms: + if transform.target is None: + target = global_targets + else: + target = [ + item + for item in self._normalize_targets(transform.target) + if isinstance(item, str) + ] + target.extend(global_targets) + + if param in target: + selected_transforms[param].append(transform) return selected_transforms @@ -975,6 +966,23 @@ def _create_valid_parameters_transforms_mapping( # if not, skip that, pointing to the previous # Create a node along the way + explicit_target_names = set(self._explicit_global_targets()) + for transform in self.transforms: + if transform.target is not None: + explicit_target_names.update( + item + for item in self._normalize_targets(transform.target) + if isinstance(item, str) + ) + + invalid_targets = explicit_target_names.difference(param_names) + if invalid_targets: + raise base.InvalidDecoratorException( + f"Function: {fn.__name__} with parameters {param_names} does not have " + f"dependency/ies {sorted(invalid_targets)}. @pipe_input requires target " + "parameter names to match the function parameters." + ) + if not mapping: # This reverts back to legacy chaining through first parameter and checks first parameter first_parameter = param_names[0] @@ -985,28 +993,27 @@ def _create_valid_parameters_transforms_mapping( f"Thus it might not be compatible with some other decorators" ) mapping[first_parameter] = self.transforms - # TODO: validate that all transforms have a target in case multiple parameters targeted - # else: - # # in case we set target this checks that each transform has at least one target parameter - # transform_set = [] - # for param in mapping: - # transform_set.extend(mapping[param]) - # transform_set = set(transform_set) - # if len(transform_set) != len(self.transforms): - # raise MissingTargetError( - # "The on_input settings are unclear. Please make sure all transforms " - # "either have specified individually or globally a target or there is " - # "no on_input usage." - # ) - - # similar to above we check that the target parameter is among the actual function parameters - if next(iter(mapping)) not in param_names: - raise base.InvalidDecoratorException( - f"Function: {fn.__name__} with parameters {param_names} does not a have " - f"dependency {next(iter(mapping))}. @pipe_input requires the parameter " - f"names to match the function parameters. Thus it might not be compatible " - f"with some other decorators." - ) + elif not self._explicit_global_targets(): + # If there are local targets but no global target, make sure every transform has + # a target so partially targeted chains do not silently fall back to legacy behavior. + transform_set = [] + for param in mapping: + transform_set.extend(mapping[param]) + if len(set(transform_set)) != len(self.transforms): + raise base.InvalidDecoratorException( + "The on_input settings are unclear. Please make sure all transforms " + "either have a target specified individually or globally, or that no " + "on_input targets are used." + ) + + for target_parameter in mapping: + if target_parameter not in param_names: + raise base.InvalidDecoratorException( + f"Function: {fn.__name__} with parameters {param_names} does not a have " + f"dependency {target_parameter}. @pipe_input requires the parameter " + f"names to match the function parameters. Thus it might not be compatible " + f"with some other decorators." + ) return mapping diff --git a/tests/function_modifiers/test_macros.py b/tests/function_modifiers/test_macros.py index f8da2a234..577f4e01e 100644 --- a/tests/function_modifiers/test_macros.py +++ b/tests/function_modifiers/test_macros.py @@ -348,20 +348,21 @@ def function_multiple_same_type_params(p1: int, p2: int, p3: int) -> int: return p1 + p2 + p3 -# TODO: in case of multiple paramters need some type checking -# def function_multiple_diverse_type_params(p1: int, p2: str, p3: int) -> int: -# return p1 + len(p2) + p3 - +def test_pipe_input_mapping_args_targets_global_list(): + n = node.Node.from_fn(function_multiple_same_type_params) -def test_pipe_input_on_input_error_unless_string_or_none(): - with pytest.raises(NotImplementedError): - decorator = pipe_input( # noqa - step(_test_apply_function, source("bar_upstream"), baz=value(10)).named("node_1"), - step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), - step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), - on_input=["p2", "p3"], - namespace="abc", - ) + decorator = pipe_input( + step(_test_apply_function, source("bar_upstream"), baz=value(10)).named("node_1"), + step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), + on_input=["p2", "p3"], + namespace="abc", + ) + nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) + nodes_by_name = {item.name: item for item in nodes} + assert nodes_by_name["abc_p2.node_1"](p2=1, bar_upstream=3) == 14 + assert nodes_by_name["abc_p2.node_2"](**{"abc_p2.node_1": 2, "bar_upstream": 3}) == 105 + assert nodes_by_name["abc_p3.node_1"](p3=7, bar_upstream=3) == 20 + assert nodes_by_name["abc_p3.node_2"](**{"abc_p3.node_1": 11, "bar_upstream": 3}) == 114 def test_pipe_input_mapping_args_targets_global(): @@ -380,204 +381,118 @@ def test_pipe_input_mapping_args_targets_global(): assert chain_node(p2=1, bar_upstream=3) == 14 -# TODO: multiple parameter tests -# def test_pipe_input_no_namespace_with_target(): -# n = node.Node.from_fn(function_multiple_diverse_type_params) - -# decorator = pipe_input( -# step(_test_apply_function, source("bar_upstream"), baz=value(10)) -# .on_input(["p1", "p3"]) -# .named("node_1"), -# step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), -# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) -# .on_input("p3") -# .named("node_3"), -# on_input="p2", -# namespace=None, -# ) -# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) -# final_node = nodes[0].name -# p1_node = nodes[1].name -# p2_node1 = nodes[2].name -# p2_node2 = nodes[3].name -# p2_node3 = nodes[4].name -# p3_node1 = nodes[5].name -# p3_node2 = nodes[6].name - -# assert final_node == "function_multiple_diverse_type_params" -# assert p1_node == "p1.node_1" -# assert p2_node1 == "p2.node_1" -# assert p2_node2 == "p2.node_2" -# assert p2_node3 == "p2.node_3" -# assert p3_node1 == "p3.node_1" -# assert p3_node2 == "p3.node_3" - - -# def test_pipe_input_elipsis_namespace_with_target(): -# n = node.Node.from_fn(function_multiple_diverse_type_params) - -# decorator = pipe_input( -# step(_test_apply_function, source("bar_upstream"), baz=value(10)) -# .on_input(["p1", "p3"]) -# .named("node_1"), -# step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), -# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) -# .on_input("p3") -# .named("node_3"), -# namespace=..., -# on_input="p2", -# ) -# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) -# final_node = nodes[0].name -# p1_node = nodes[1].name -# p2_node1 = nodes[2].name -# p2_node2 = nodes[3].name -# p2_node3 = nodes[4].name -# p3_node1 = nodes[5].name -# p3_node2 = nodes[6].name - -# assert final_node == "function_multiple_diverse_type_params" -# assert p1_node == "p1.node_1" -# assert p2_node1 == "p2.node_1" -# assert p2_node2 == "p2.node_2" -# assert p2_node3 == "p2.node_3" -# assert p3_node1 == "p3.node_1" -# assert p3_node2 == "p3.node_3" - - -# def test_pipe_input_custom_namespace_with_target(): -# n = node.Node.from_fn(function_multiple_diverse_type_params) - -# decorator = pipe_input( -# step(_test_apply_function, source("bar_upstream"), baz=value(10)) -# .on_input(["p1", "p3"]) -# .named("node_1"), -# step(_test_apply_function, source("bar_upstream"), baz=value(100)).named("node_2"), -# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) -# .on_input("p3") -# .named("node_3"), -# namespace="abc", -# on_input="p2", -# ) -# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) -# final_node = nodes[0].name -# p1_node = nodes[1].name -# p2_node1 = nodes[2].name -# p2_node2 = nodes[3].name -# p2_node3 = nodes[4].name -# p3_node1 = nodes[5].name -# p3_node2 = nodes[6].name - -# assert final_node == "function_multiple_diverse_type_params" -# assert p1_node == "abc_p1.node_1" -# assert p2_node1 == "abc_p2.node_1" -# assert p2_node2 == "abc_p2.node_2" -# assert p2_node3 == "abc_p2.node_3" -# assert p3_node1 == "abc_p3.node_1" -# assert p3_node2 == "abc_p3.node_3" - - -# def test_pipe_input_mapping_args_targets_local(): -# n = node.Node.from_fn(function_multiple_diverse_type_params) - -# decorator = pipe_input( -# step(_test_apply_function, source("bar_upstream"), baz=value(10)) -# .on_input(["p1", "p3"]) -# .named("node_1"), -# step(_test_apply_function, source("bar_upstream"), baz=value(100)) -# .on_input("p2") -# .named("node_2"), -# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) -# .on_input("p3") -# .named("node_3"), -# namespace="abc", -# ) -# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) -# nodes_by_name = {item.name: item for item in nodes} -# chain_node_1 = nodes_by_name["abc_p1.node_1"] -# chain_node_2 = nodes_by_name["abc_p2.node_2"] -# chain_node_3_first = nodes_by_name["abc_p3.node_1"] -# assert chain_node_1(p1=1, bar_upstream=3) == 14 -# assert chain_node_2(p2=1, bar_upstream=3) == 104 -# assert chain_node_3_first(p3=7, bar_upstream=3) == 20 -# -# -# def test_pipe_input_mapping_args_targets_local_adds_to_global(): -# n = node.Node.from_fn(function_multiple_same_type_params) - -# decorator = pipe_input( -# step(_test_apply_function, source("bar_upstream"), baz=value(10)) -# .on_input(["p1", "p2"]) -# .named("node_1"), -# step(_test_apply_function, source("bar_upstream"), baz=value(100)) -# .on_input("p2") -# .named("node_2"), -# step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), -# on_input="p3", -# namespace="abc", -# ) -# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) -# nodes_by_name = {item.name: item for item in nodes} -# p1_node = nodes_by_name["abc_p1.node_1"] -# p2_node1 = nodes_by_name["abc_p2.node_1"] -# p2_node2 = nodes_by_name["abc_p2.node_2"] -# p3_node1 = nodes_by_name["abc_p3.node_1"] -# p3_node2 = nodes_by_name["abc_p3.node_2"] -# p3_node3 = nodes_by_name["abc_p3.node_3"] - -# assert p1_node(p1=1, bar_upstream=3) == 14 -# assert p2_node1(p2=7, bar_upstream=3) == 20 -# assert p2_node2(**{"abc_p2.node_1": 2, "bar_upstream": 3}) == 105 -# assert p3_node1(p3=9, bar_upstream=3) == 22 -# assert p3_node2(**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 116 -# assert p3_node3(**{"abc_p3.node_2": 17, "bar_upstream": 3}) == 1020 - - -# def test_pipe_input_fails_with_missing_targets(): -# n = node.Node.from_fn(function_multiple_same_type_params) - -# decorator = pipe_input( -# step(_test_apply_function, source("bar_upstream"), baz=value(10)) -# .on_input(["p1", "p2"]) -# .named("node_1"), -# step(_test_apply_function, source("bar_upstream"), baz=value(100)) -# .on_input("p2") -# .named("node_2"), -# step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), -# namespace="abc", -# ) -# with pytest.raises(hamilton.function_modifiers.macros.MissingTargetError): -# nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) # noqa - - -# def test_pipe_input_decorator_with_target_no_collapse_multi_node(): -# n = node.Node.from_fn(function_multiple_same_type_params) - -# decorator = pipe_input( -# step(_test_apply_function, source("bar_upstream"), baz=value(10)) -# .on_input(["p1", "p3"]) -# .named("node_1"), -# step(_test_apply_function, source("bar_upstream"), baz=value(100)) -# .on_input("p2") -# .named("node_2"), -# step(_test_apply_function, source("bar_upstream"), baz=value(1000)) -# .on_input("p3") -# .named("node_3"), -# namespace="abc", -# ) -# nodes = decorator.transform_dag([n], {}, function_multiple_diverse_type_params) -# nodes_by_name = {item.name: item for item in nodes} -# final_node = nodes_by_name["function_multiple_same_type_params"] -# chain_node_1 = nodes_by_name["abc_p1.node_1"] -# chain_node_2 = nodes_by_name["abc_p2.node_2"] -# chain_node_3_first = nodes_by_name["abc_p3.node_1"] -# chain_node_3_second = nodes_by_name["abc_p3.node_3"] -# assert len(nodes_by_name) == 5 -# assert chain_node_1(p1=1, bar_upstream=3) == 14 -# assert chain_node_2(p2=1, bar_upstream=3) == 104 -# assert chain_node_3_first(p3=7, bar_upstream=3) == 20 -# assert chain_node_3_second(**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 1016 -# assert final_node(**{"abc_p1.node_1": 3, "abc_p2.node_2": 4, "abc_p3.node_3": 5}) == 12 +def test_pipe_input_mapping_args_targets_local(): + n = node.Node.from_fn(function_multiple_same_type_params) + + decorator = pipe_input( + step(_test_apply_function, source("bar_upstream"), baz=value(10)) + .on_input(["p1", "p3"]) + .named("node_1"), + step(_test_apply_function, source("bar_upstream"), baz=value(100)) + .on_input("p2") + .named("node_2"), + step(_test_apply_function, source("bar_upstream"), baz=value(1000)) + .on_input("p3") + .named("node_3"), + namespace="abc", + ) + nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) + nodes_by_name = {item.name: item for item in nodes} + assert nodes_by_name["abc_p1.node_1"](p1=1, bar_upstream=3) == 14 + assert nodes_by_name["abc_p2.node_2"](p2=1, bar_upstream=3) == 104 + assert nodes_by_name["abc_p3.node_1"](p3=7, bar_upstream=3) == 20 + assert nodes_by_name["abc_p3.node_3"](**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 1016 + + +def test_pipe_input_mapping_args_targets_local_adds_to_global(): + n = node.Node.from_fn(function_multiple_same_type_params) + + decorator = pipe_input( + step(_test_apply_function, source("bar_upstream"), baz=value(10)) + .on_input(["p1", "p2"]) + .named("node_1"), + step(_test_apply_function, source("bar_upstream"), baz=value(100)) + .on_input("p2") + .named("node_2"), + step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), + on_input="p3", + namespace="abc", + ) + nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) + nodes_by_name = {item.name: item for item in nodes} + + assert nodes_by_name["abc_p1.node_1"](p1=1, bar_upstream=3) == 14 + assert nodes_by_name["abc_p2.node_1"](p2=7, bar_upstream=3) == 20 + assert nodes_by_name["abc_p2.node_2"](**{"abc_p2.node_1": 2, "bar_upstream": 3}) == 105 + assert nodes_by_name["abc_p3.node_1"](p3=9, bar_upstream=3) == 22 + assert nodes_by_name["abc_p3.node_2"](**{"abc_p3.node_1": 13, "bar_upstream": 3}) == 116 + assert nodes_by_name["abc_p3.node_3"](**{"abc_p3.node_2": 17, "bar_upstream": 3}) == 1020 + + +def test_pipe_input_multi_target_default_namespace_uses_parameter_names(): + n = node.Node.from_fn(function_multiple_same_type_params) + + decorator = pipe_input( + step(_test_apply_function, source("bar_upstream"), baz=value(10)).named("node_1"), + on_input=["p1", "p3"], + ) + nodes = decorator.transform_dag([n], {}, function_multiple_same_type_params) + nodes_by_name = {item.name: item for item in nodes} + + assert "p1.node_1" in nodes_by_name + assert "p3.node_1" in nodes_by_name + assert nodes_by_name["p1.node_1"](p1=1, bar_upstream=3) == 14 + assert nodes_by_name["p3.node_1"](p3=7, bar_upstream=3) == 20 + + +def test_pipe_input_fails_with_missing_targets(): + n = node.Node.from_fn(function_multiple_same_type_params) + + decorator = pipe_input( + step(_test_apply_function, source("bar_upstream"), baz=value(10)) + .on_input(["p1", "p2"]) + .named("node_1"), + step(_test_apply_function, source("bar_upstream"), baz=value(100)) + .on_input("p2") + .named("node_2"), + step(_test_apply_function, source("bar_upstream"), baz=value(1000)).named("node_3"), + namespace="abc", + ) + with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): + decorator.transform_dag([n], {}, function_multiple_same_type_params) + + +def test_pipe_input_fails_with_unknown_target(): + n = node.Node.from_fn(function_multiple_same_type_params) + + decorator = pipe_input( + step(_test_apply_function, source("bar_upstream"), baz=value(10)).named("node_1"), + on_input=["p2", "missing"], + namespace="abc", + ) + with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): + decorator.transform_dag([n], {}, function_multiple_same_type_params) + + +@pytest.mark.parametrize("target", [..., []]) +def test_pipe_input_fails_with_invalid_global_target(target): + with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): + pipe_input( + step(_test_apply_function, source("bar_upstream"), baz=value(10)).named("node_1"), + on_input=target, + ) + + +@pytest.mark.parametrize("target", [1, ["p1", 1]]) +def test_pipe_input_fails_with_invalid_local_target_type(target): + with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): + step(_test_apply_function, source("bar_upstream"), baz=value(10)).on_input(target) + + +@pytest.mark.parametrize("target", [..., []]) +def test_pipe_input_fails_with_invalid_local_target(target): + with pytest.raises(hamilton.function_modifiers.base.InvalidDecoratorException): + step(_test_apply_function, source("bar_upstream"), baz=value(10)).on_input(target) def test_pipe_decorator_positional_variable_args(): @@ -713,58 +628,58 @@ def test_pipe_end_to_end_target_global(): ) -# TODO: For multiple parameters end-to-end -# def test_pipe_end_to_end_target_local(): -# dr = ( -# driver.Builder() -# .with_modules(tests.resources.pipe_input) -# .with_adapter(base.DefaultAdapter()) -# .with_config({"calc_c": True}) -# .build() -# ) - -# inputs = { -# "input_1": 10, -# "input_2": 20, -# "input_3": 30, -# } -# result = dr.execute( -# [ -# "chain_1_using_pipe_input_target_local", -# "chain_1_not_using_pipe_input_target_local", -# ], -# inputs=inputs, -# ) -# assert ( -# result["chain_1_not_using_pipe_input_target_local"] -# == result["chain_1_using_pipe_input_target_local"] -# ) - -# def test_pipe_end_to_end_target_mixed(): -# dr = ( -# driver.Builder() -# .with_modules(tests.resources.pipe_input) -# .with_adapter(base.DefaultAdapter()) -# .with_config({"calc_c": True}) -# .build() -# ) - -# inputs = { -# "input_1": 10, -# "input_2": 20, -# "input_3": 30, -# } -# result = dr.execute( -# [ -# "chain_1_using_pipe_input_target_mixed", -# "chain_1_not_using_pipe_input_target_mixed", -# ], -# inputs=inputs, -# ) -# assert ( -# result["chain_1_not_using_pipe_input_target_mixed"] -# == result["chain_1_using_pipe_input_target_mixed"] -# ) +def test_pipe_end_to_end_target_local(): + dr = ( + driver.Builder() + .with_modules(tests.resources.pipe_input) + .with_adapter(base.DefaultAdapter()) + .with_config({"calc_c": True}) + .build() + ) + + inputs = { + "input_1": 10, + "input_2": 20, + "input_3": 30, + } + result = dr.execute( + [ + "chain_1_using_pipe_input_target_local", + "chain_1_not_using_pipe_input_target_local", + ], + inputs=inputs, + ) + assert ( + result["chain_1_not_using_pipe_input_target_local"] + == result["chain_1_using_pipe_input_target_local"] + ) + + +def test_pipe_end_to_end_target_mixed(): + dr = ( + driver.Builder() + .with_modules(tests.resources.pipe_input) + .with_adapter(base.DefaultAdapter()) + .with_config({"calc_c": True}) + .build() + ) + + inputs = { + "input_1": 10, + "input_2": 20, + "input_3": 30, + } + result = dr.execute( + [ + "chain_1_using_pipe_input_target_mixed", + "chain_1_not_using_pipe_input_target_mixed", + ], + inputs=inputs, + ) + assert ( + result["chain_1_not_using_pipe_input_target_mixed"] + == result["chain_1_using_pipe_input_target_mixed"] + ) def result_from_downstream_function() -> int: diff --git a/tests/resources/pipe_input.py b/tests/resources/pipe_input.py index 127808fb4..945c6b41f 100644 --- a/tests/resources/pipe_input.py +++ b/tests/resources/pipe_input.py @@ -110,53 +110,52 @@ def chain_1_not_using_pipe_input_target_global( return v + e * 10 -# TODO: for tests in case of multiple target parameters -# @pipe_input( -# step(_add_one).on_input(["v", "w"]).named("a"), -# step(_add_two).on_input("v").named("b"), -# step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), -# step(_add_n, n=source("input_1")).on_input("v").named("d"), -# step(_multiply_n, n=source("input_2")).on_input("w").named("e"), -# namespace="local", -# ) -# def chain_1_using_pipe_input_target_local(v: int, w: int) -> int: -# return v + w * 10 - - -# def chain_1_not_using_pipe_input_target_local( -# v: int, w: int, input_1: int, input_2: int, calc_c: bool = False -# ) -> int: -# av = _add_one(v) -# aw = _add_one(w) -# bv = _add_two(av) -# cw = _add_n(aw, n=3) if calc_c else aw -# dv = _add_n(bv, n=input_1) -# ew = _multiply_n(cw, n=input_2) -# return dv + ew * 10 - - -# @pipe_input( -# step(_add_one).on_input("w").named("a"), -# step(_add_two).named("b"), -# step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), -# step(_add_n, n=source("input_1")).named("d"), -# step(_multiply_n, n=source("input_2")).on_input("w").named("e"), -# namespace="mixed", -# on_input="v", -# ) -# def chain_1_using_pipe_input_target_mixed(v: int, w: int) -> int: -# return v + w * 10 - - -# def chain_1_not_using_pipe_input_target_mixed( -# v: int, w: int, input_1: int, input_2: int, calc_c: bool = False -# ) -> int: -# av = _add_one(v) -# aw = _add_one(w) -# bv = _add_two(av) -# cv = _add_n(bv, n=3) if calc_c else bv -# cw = _add_n(aw, n=3) if calc_c else aw -# dv = _add_n(cv, n=input_1) -# ev = _multiply_n(dv, n=input_2) -# ew = _multiply_n(cw, n=input_2) -# return ev + ew * 10 +@pipe_input( + step(_add_one).on_input(["v", "w"]).named("a"), + step(_add_two).on_input("v").named("b"), + step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), + step(_add_n, n=source("input_1")).on_input("v").named("d"), + step(_multiply_n, n=source("input_2")).on_input("w").named("e"), + namespace="local", +) +def chain_1_using_pipe_input_target_local(v: int, w: int) -> int: + return v + w * 10 + + +def chain_1_not_using_pipe_input_target_local( + v: int, w: int, input_1: int, input_2: int, calc_c: bool = False +) -> int: + av = _add_one(v) + aw = _add_one(w) + bv = _add_two(av) + cw = _add_n(aw, n=3) if calc_c else aw + dv = _add_n(bv, n=input_1) + ew = _multiply_n(cw, n=input_2) + return dv + ew * 10 + + +@pipe_input( + step(_add_one).on_input("w").named("a"), + step(_add_two).named("b"), + step(_add_n, n=3).on_input("w").named("c").when(calc_c=True), + step(_add_n, n=source("input_1")).named("d"), + step(_multiply_n, n=source("input_2")).on_input("w").named("e"), + namespace="mixed", + on_input="v", +) +def chain_1_using_pipe_input_target_mixed(v: int, w: int) -> int: + return v + w * 10 + + +def chain_1_not_using_pipe_input_target_mixed( + v: int, w: int, input_1: int, input_2: int, calc_c: bool = False +) -> int: + av = _add_one(v) + aw = _add_one(w) + bv = _add_two(av) + cv = _add_n(bv, n=3) if calc_c else bv + cw = _add_n(aw, n=3) if calc_c else aw + dv = _add_n(cv, n=input_1) + ev = _multiply_n(dv, n=input_2) + ew = _multiply_n(cw, n=input_2) + return ev + ew * 10