Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 25 additions & 4 deletions iron/common/compilation/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,18 +56,39 @@

@dataclass
class DesignGenerator:
"""Lazy callable that imports source_path and calls fn_name(*args, **kwargs), returning MLIR as a string."""
"""Lazy callable that imports *source_path* and calls *fn_name*, returning MLIR.

On first call the loaded module is registered in ``sys.modules`` under
its fully-qualified package name (e.g. ``iron.operators.axpy.design``) so
that subsequent code can import it through the standard Python import
mechanism.
"""

source_path: Path
fn_name: str
args: tuple = ()
kwargs: dict[str, Any] = field(default_factory=dict)

@property
def module_name(self) -> str:
"""Python module path derived from *source_path*.

Walks up from *source_path* until the ``iron`` package root is found,
then builds a dotted path relative to it.
"""
parts = list(self.source_path.resolve().parts)
try:
idx = parts.index("iron")
except ValueError:
return self.source_path.stem
rel_parts = parts[idx:-1] + [self.source_path.stem]
return ".".join(rel_parts)

def __call__(self) -> str:
spec = importlib.util.spec_from_file_location(
self.source_path.name, self.source_path
)
module_name = self.module_name
spec = importlib.util.spec_from_file_location(module_name, self.source_path)
module = importlib.util.module_from_spec(spec)
sys.modules[module_name] = module
spec.loader.exec_module(module)
return str(getattr(module, self.fn_name)(*self.args, **self.kwargs))

Expand Down
13 changes: 6 additions & 7 deletions iron/common/compilation/fusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,7 @@
from __future__ import annotations

import numpy as np
import importlib.util
import importlib
from functools import partial
from pathlib import Path
from aie import ir
Expand Down Expand Up @@ -74,18 +74,17 @@ def extract_runtime_sequence_arg_types(dev_op: Any) -> list[Any]:
def get_child_mlir_module(mlir_artifact: PythonGeneratedMLIRArtifact) -> Any:
"""Extract MLIR module from a PythonGeneratedMLIRArtifact.

Uses the artifact's DesignGenerator to dynamically import the design
module and call the callback, returning the raw (non-stringified) MLIR
module object for further inspection by the fusion pass.
Imports the design module via the standard Python import mechanism
(relying on it having been registered in ``sys.modules`` by
``DesignGenerator.__call__``) and calls the callback, returning the
raw (non-stringified) MLIR module object.
"""
if not isinstance(mlir_artifact, PythonGeneratedMLIRArtifact):
raise TypeError(
f"Expected PythonGeneratedMLIRArtifact, got {type(mlir_artifact).__name__}"
)
gen = mlir_artifact.generator
spec = importlib.util.spec_from_file_location(gen.source_path.name, gen.source_path)
module = importlib.util.module_from_spec(spec)
spec.loader.exec_module(module)
module = importlib.import_module(gen.module_name)
callback_function = getattr(module, gen.fn_name)
return callback_function(*gen.args, **gen.kwargs)

Expand Down