Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
217 changes: 112 additions & 105 deletions hamilton/function_modifiers/macros.py
Original file line number Diff line number Diff line change
Expand Up @@ -493,26 +493,27 @@ def named(self, name: str, namespace: NamespaceType = ...) -> "Applicable":
# on_input / on_output are the same but here for naming convention
# I know there is a way to dynamically resolve this to revert to a common function
# just can't remember it now or find it online...
# TODO: adding the option to select target parameter for each transform
# def on_input(self, target: base.TargetType) -> "Applicable":
# """Add Target on a single function level.

# This determines to which node(s) it will applies. Should match the same naming convention
# as the NodeTransorfmLifecycle child class (for example NodeTransformer).

# :param target: Which node(s) to apply on top of
# :return: The Applicable with specified target
# """
# return Applicable(
# fn=self.fn,
# _resolvers=self.resolvers,
# _name=self.name,
# _namespace=self.namespace,
# _target=target if target is not None else self.target,
# args=self.args,
# kwargs=self.kwargs,
# target_fn=self.target_fn,
# )
def on_input(self, target: base.TargetType) -> "Applicable":
"""Add Target on a single function level.

This determines to which node(s) it will applies. Should match the same naming convention
as the NodeTransorfmLifecycle child class (for example NodeTransformer).

:param target: Which node(s) to apply on top of
:return: The Applicable with specified target
"""
base.NodeTransformer._early_validate_target(target=target, allow_multiple=True)
pipe_input._validate_on_input_target(target=target)
return Applicable(
fn=self.fn,
_resolvers=self.resolvers,
_name=self.name,
_namespace=self.namespace,
_target=target if target is not None else self.target,
args=self.args,
kwargs=self.kwargs,
target_fn=self.target_fn,
)

def on_output(self, target: base.TargetType) -> "Applicable":
"""Add Target on a single function level.
Expand Down Expand Up @@ -830,8 +831,7 @@ def final_result(upstream_int: int) -> int:
consumption/output later. Setting the namespace in individual nodes as well as in ``pipe_input`` is not yet supported.

3. ``on_input`` -- this selects which input we will run the pipeline on.
In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter. Let us know if you wish to expand to other use-cases.
You can track the progress on this topic via: https://github.com/apache/hamilton/issues/1177
In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter.

The following would apply function *_add_one* and *_add_two* to ``p2``:

Expand All @@ -845,39 +845,22 @@ def final_result(upstream_int: int) -> int:
def final_result(p1: int, p2: int, p3: int) -> int:
return upstream_int

.. |
THIS IS COMMENTED OUT, I.E. SPHINX WILL NOT AUTODOC IT, HERE IN CASE WE ENABLE MULTIPLE PARAMETER TARGETS
For extra control in case of multiple function arguments (parameters), we can also specify the target parameter that we wish to transform.
In case ``on_input`` is set to None (default), we apply ``pipe_input`` on the first parameter only. If ``on_input`` is set for a specific transform
make sure the other ones are also set either through a global setting or individually, otherwise it is unclear which transforms target which parameters.
For extra control in case of multiple function arguments (parameters), we can specify target parameters globally
with ``on_input=["p1", "p3"]`` or locally with ``step(...).on_input("p2")``. If a local ``on_input`` is used
without a global ``on_input``, every step needs a local target so the target distribution is unambiguous.

The following applies *_add_one* to ``p1``, ``p3`` and *_add_two* to ``p2``

.. code-block:: python

@pipe_input(
step(_add_one).on_input(["p1","p3"])
step(_add_two, y=source("upstream_node")).on_input("p2")
)
def final_result(p1: int, p2: int, p3: int) -> int:
return p1 + p2 + p3
Lastly, a mixture of global and local is possible. The global target applies to all transforms, and local targets
add more parameters for individual transforms.

We can also do this on the global level to set for all transforms a target parameter.

Lastly, a mixture of global and local is possible, where the global selects the target parameters for
all transforms and we can select individual transforms to also target more parameters.
The following would apply function *_add_one* to all ``p1``, ``p2``, ``p3`` and *_add_two* also on ``p2``

.. code-block:: python
.. code-block:: python

@pipe_input(
step(_add_one).on_input(["p1","p3"])
step(_add_two, y=source("upstream_node")),
on_input = "p2"
)
def final_result(p1: int, p2: int, p3: int) -> int:
return upstream_int
| replace:: \
@pipe_input(
step(_add_one).on_input(["p1", "p3"]),
step(_add_two, y=source("upstream_node")),
on_input="p2",
)
def final_result(p1: int, p2: int, p3: int) -> int:
return p1 + p2 + p3
"""

def __init__(
Expand All @@ -896,27 +879,14 @@ def __init__(
:param collapse: Whether to collapse this into a single node. This is not currently supported.
:param _chain: Whether to chain the first parameter. This is the only mode that is supported. Furthermore, this is not externally exposed. ``@flow`` will make use of this.
"""
if on_input is not None:
if not isinstance(on_input, str):
raise NotImplementedError(
"on_input currently only supports a single target parameter specified by a string. "
"Please reach out if you want a more flexible option in the feature."
)
base.NodeTransformer._early_validate_target(target=on_input, allow_multiple=True)
self._validate_on_input_target(target=on_input)

self.transforms = transforms
self.collapse = collapse
self.chain = _chain
self.namespace = namespace
self.target = [on_input]

# TODO: for multiple target parameter case
# if isinstance(on_input, str): # have to do extra since strings are collections in python
# self.target = [on_input]
# elif isinstance(on_input, Collection):
# self.target = on_input
# else:
# self.target = [on_input]
self.target = self._normalize_targets(on_input)

if self.collapse:
raise NotImplementedError(
Expand All @@ -926,6 +896,29 @@ def __init__(
if self.chain:
raise NotImplementedError("@flow() is not yet supported -- this is ")

@staticmethod
def _normalize_targets(target: base.TargetType) -> list[str | None | EllipsisType]:
if isinstance(target, str): # have to do extra since strings are collections in python
return [target]
elif isinstance(target, Collection):
return list(target)
else:
return [target]

@staticmethod
def _validate_on_input_target(target: base.TargetType):
if target is ...:
raise base.InvalidDecoratorException(
"Cannot apply Ellipsis(...) to on_input. Use None, a string, or a non-empty collection of strings."
)
if isinstance(target, Collection) and not isinstance(target, str) and len(target) == 0:
raise base.InvalidDecoratorException(
"Cannot apply an empty collection to on_input. Use None, a string, or a non-empty collection of strings."
)

def _explicit_global_targets(self) -> list[str]:
return [target for target in self.target if isinstance(target, str)]

def _distribute_transforms_to_parameters(
self, params: dict[str, type[type]]
) -> dict[str, list[Applicable]]:
Expand All @@ -941,23 +934,21 @@ def _distribute_transforms_to_parameters(
"""

selected_transforms = defaultdict(list)
global_targets = self._explicit_global_targets()
for param in params:
if param in self.target:
selected_transforms[param].extend(self.transforms)
# TODO: in case of multiple parameters we can set individual targets and resolve them here
# for transform in self.transforms:
# target = transform.target
# # In case there is no target set on applicable we assign global target
# if target is None:
# target = self.target
# elif isinstance(target, str): # user selects single target via string
# target = [target]
# target.extend(self.target)
# elif isinstance(target, Collection): # user inputs a list of targets
# target.extend(self.target)

# if param in target:
# selected_transforms[param].append(transform)
for transform in self.transforms:
if transform.target is None:
target = global_targets
else:
target = [
item
for item in self._normalize_targets(transform.target)
if isinstance(item, str)
]
target.extend(global_targets)

if param in target:
selected_transforms[param].append(transform)

return selected_transforms

Expand All @@ -975,6 +966,23 @@ def _create_valid_parameters_transforms_mapping(
# if not, skip that, pointing to the previous
# Create a node along the way

explicit_target_names = set(self._explicit_global_targets())
for transform in self.transforms:
if transform.target is not None:
explicit_target_names.update(
item
for item in self._normalize_targets(transform.target)
if isinstance(item, str)
)

invalid_targets = explicit_target_names.difference(param_names)
if invalid_targets:
raise base.InvalidDecoratorException(
f"Function: {fn.__name__} with parameters {param_names} does not have "
f"dependency/ies {sorted(invalid_targets)}. @pipe_input requires target "
"parameter names to match the function parameters."
)

if not mapping:
# This reverts back to legacy chaining through first parameter and checks first parameter
first_parameter = param_names[0]
Expand All @@ -985,28 +993,27 @@ def _create_valid_parameters_transforms_mapping(
f"Thus it might not be compatible with some other decorators"
)
mapping[first_parameter] = self.transforms
# TODO: validate that all transforms have a target in case multiple parameters targeted
# else:
# # in case we set target this checks that each transform has at least one target parameter
# transform_set = []
# for param in mapping:
# transform_set.extend(mapping[param])
# transform_set = set(transform_set)
# if len(transform_set) != len(self.transforms):
# raise MissingTargetError(
# "The on_input settings are unclear. Please make sure all transforms "
# "either have specified individually or globally a target or there is "
# "no on_input usage."
# )

# similar to above we check that the target parameter is among the actual function parameters
if next(iter(mapping)) not in param_names:
raise base.InvalidDecoratorException(
f"Function: {fn.__name__} with parameters {param_names} does not a have "
f"dependency {next(iter(mapping))}. @pipe_input requires the parameter "
f"names to match the function parameters. Thus it might not be compatible "
f"with some other decorators."
)
elif not self._explicit_global_targets():
# If there are local targets but no global target, make sure every transform has
# a target so partially targeted chains do not silently fall back to legacy behavior.
transform_set = []
for param in mapping:
transform_set.extend(mapping[param])
if len(set(transform_set)) != len(self.transforms):
raise base.InvalidDecoratorException(
"The on_input settings are unclear. Please make sure all transforms "
"either have a target specified individually or globally, or that no "
"on_input targets are used."
)

for target_parameter in mapping:
if target_parameter not in param_names:
raise base.InvalidDecoratorException(
f"Function: {fn.__name__} with parameters {param_names} does not a have "
f"dependency {target_parameter}. @pipe_input requires the parameter "
f"names to match the function parameters. Thus it might not be compatible "
f"with some other decorators."
)

return mapping

Expand Down
Loading