Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
28 commits
Select commit Hold shift + click to select a range
aaf5505
Initial integration of hill climbing
Jake248Newman Jun 19, 2026
1661dcf
Discovery entry point
Jake248Newman Jun 19, 2026
078e93c
Update causal_testing/__main__.py
Jake248Newman Jun 22, 2026
6f9626b
rustworkx for deterministic remove_cycles
Jake248Newman Jun 22, 2026
638c656
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramework
Jake248Newman Jun 22, 2026
cb0f5e6
Update causal_testing/main.py
Jake248Newman Jun 23, 2026
8538083
Initial testing of discovery and CLI
Jake248Newman Jun 23, 2026
05c9e9e
Score and tier based fitness
Jake248Newman Jun 24, 2026
114feaf
Endpoint added max iterations
Jake248Newman Jun 29, 2026
168c788
Added "ignore_cycles" parameter to "add_nodes_from"
jmafoster1 Jul 1, 2026
9db65b6
Max iterations type error fix
Jake248Newman Jul 1, 2026
f0060a3
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramework
Jake248Newman Jul 1, 2026
76be77f
Switched to using nx_pydot instead of nx_agraph to remove dependency …
jmafoster1 Jul 1, 2026
91448fc
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramewor…
jmafoster1 Jul 1, 2026
ebfd36c
Updated LR estimator documentation since 95% confidence isn't hardcoded
jmafoster1 Jul 1, 2026
2dc409a
Now supports categorical data
jmafoster1 Jul 1, 2026
040c11a
Colouring pass/fail/error edges in output
jmafoster1 Jul 1, 2026
3da2903
Adding in dashed lines for failed independences, and now supports a s…
jmafoster1 Jul 2, 2026
9405657
Fixed query
jmafoster1 Jul 2, 2026
98ab1c6
Regex support in dot files
Jake248Newman Jul 2, 2026
3028952
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramework
Jake248Newman Jul 2, 2026
6147593
Fixed multiple dataframe concatenation
jmafoster1 Jul 2, 2026
1e03a19
Renamed "normalised_counts"
jmafoster1 Jul 2, 2026
b6681a6
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramewor…
jmafoster1 Jul 2, 2026
4d45444
Causal Discovery doc page
Jake248Newman Jul 3, 2026
58ffda2
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramework
Jake248Newman Jul 3, 2026
c95e9aa
Causal Discovery entry points
jmafoster1 Jul 3, 2026
fb7c58e
Merge branch 'main' of github.com:Jake248Newman/CausalTestingFramewor…
jmafoster1 Jul 3, 2026
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
169 changes: 105 additions & 64 deletions causal_testing/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,10 @@
import logging
import os
import tempfile
import pandas as pd
from pathlib import Path
from importlib.metadata import entry_points
import networkx as nx

from causal_testing.testing.metamorphic_relation import generate_causal_tests

Expand All @@ -20,74 +23,112 @@ def main() -> None:

# Parse arguments
args = parse_args()

if args.command == Command.GENERATE:
logging.info("Generating causal tests")
generate_causal_tests(
args.dag_path,
args.output,
args.ignore_cycles,
args.threads,
effect_type=args.effect_type,
estimate_type=args.estimate_type,
estimator=args.estimator,
skip=False,
)
logging.info("Causal test generation completed successfully")
return

# Setup logging
setup_logging(args.verbose)

# Create paths object
paths = CausalTestingPaths(
dag_path=args.dag_path,
data_paths=args.data_paths,
test_config_path=args.test_config,
output_path=args.output,
)

# Create and setup framework
framework = CausalTestingFramework(paths, ignore_cycles=args.ignore_cycles, query=args.query)
framework.setup()

# Load and run tests
framework.load_tests()

if args.batch_size > 0:
logging.info(f"Running tests in batches of size {args.batch_size}")
with tempfile.TemporaryDirectory() as tmpdir:
output_files = []
for i, results in enumerate(
framework.run_tests_in_batches(
batch_size=args.batch_size,
silent=args.silent,
adequacy=args.adequacy,
bootstrap_size=args.bootstrap_size,
# Dispatch on the sub-command selected on the command line.
# NOTE(review): indentation is missing in this rendering of the diff; the nesting described in these
# comments is the one implied by the match/case syntax — confirm against the real file.
match args.command:
case Command.GENERATE:
# Generate causal tests from the DAG at args.dag_path and write them to args.output.
logging.info("Generating causal tests")
generate_causal_tests(
args.dag_path,
args.output,
args.ignore_cycles,
args.threads,
effect_type=args.effect_type,
estimate_type=args.estimate_type,
estimator=args.estimator,
skip=False,
)
logging.info("Causal test generation completed successfully.")

case Command.DISCOVER:
# Discovery techniques are looked up by name among the entry points registered in the "discovery" group,
# which is how a custom technique can be made selectable without changing this module.
discover_map = {ff.name: ff for ff in entry_points(group="discovery")}
# Fail early, listing the registered technique names, when the requested one is unknown.
if args.technique not in discover_map:
raise ValueError(
f"Unsupported technique {args.technique}. Supported: {sorted(discover_map)}. "
"If you have implemented a custom technique, you will need to add this to your entrypoints via your "
"pyproject.toml file."
)
):
temp_file_path = os.path.join(tmpdir, f"output_{i}.json")
framework.save_results(results, temp_file_path)
output_files.append(temp_file_path)
del results

# Now stitch the results together from the temporary files
all_results = []
for file_path in output_files:
with open(file_path, "r", encoding="utf-8") as f:
all_results.extend(json.load(f))

output_path = Path(args.output)
output_path.parent.mkdir(parents=True, exist_ok=True)

with open(args.output, "w", encoding="utf-8") as f:
json.dump(all_results, f, indent=4)
else:
logging.info("Running tests in regular mode")
results = framework.run_tests(silent=args.silent, adequacy=args.adequacy, bootstrap_size=args.bootstrap_size)
framework.save_results(results)

logging.info("Causal testing completed successfully.")
# Collect the free-form `arg_name=arg_value` options that are forwarded verbatim (as strings) to the
# constructor of the selected discovery technique.
kwargs = {}
# `or []` tolerates the option being left unset (None) on the command line.
for argument in args.technique_kwargs or []:
    # Split on the first "=" only, so that a value may itself contain "=" (e.g. `expr=a==b`).
    name, separator, value = argument.partition("=")
    if not separator or not name:
        raise ValueError(f"Malformed argument {argument}. Should be specified as `arg_name=arg_value`")
    kwargs[name] = value

logging.info("Discovering causal structure")
# Every CSV is read with its own zero-based index, so the concatenated frame must be renumbered to avoid
# duplicate index labels. `ignore_index=True` discards the per-file indices; a plain `reset_index()` would
# instead insert them as an extra "index" column, which would then be treated as a data variable (and
# raises if the data already has a column called "index").
df = pd.concat([pd.read_csv(path) for path in args.data_paths], ignore_index=True)
if args.variables:
    # Restrict the data to the explicitly requested variables.
    df = df[args.variables]

# Edge constraints are optional DOT files; a missing file argument means "no constraints".
exclude_edges = list(nx.nx_pydot.read_dot(args.exclude_edges).edges()) if args.exclude_edges is not None else []
include_edges = list(nx.nx_pydot.read_dot(args.include_edges).edges()) if args.include_edges is not None else []

# Resolve the entry point to the technique class, run it, and write the resulting DAG to args.output.
discover_class = discover_map[args.technique].load()
discover = discover_class(
    df=df,
    exclude_edges=exclude_edges,
    include_edges=include_edges,
    **kwargs,
)
evolved_dag = discover.discover()
discover.write_dot(evolved_dag, args.output)
logging.info("Causal structure discovery completed successfully.")
case Command.TEST:
# Gather every input and output location into a single paths object
paths = CausalTestingPaths(
    dag_path=args.dag_path,
    data_paths=args.data_paths,
    test_config_path=args.test_config,
    output_path=args.output,
)

# Build the framework, prepare it, then load the test cases
framework = CausalTestingFramework(paths, ignore_cycles=args.ignore_cycles, query=args.query)
framework.setup()
framework.load_tests()

if args.batch_size > 0:
    logging.info(f"Running tests in batches of size {args.batch_size}")
    with tempfile.TemporaryDirectory() as tmpdir:
        # Each batch is written to its own temporary file and released straight away,
        # so that only one batch of results is held in memory at a time.
        batch_files = []
        batches = framework.run_tests_in_batches(
            batch_size=args.batch_size,
            silent=args.silent,
            adequacy=args.adequacy,
            bootstrap_size=args.bootstrap_size,
        )
        for batch_number, batch_results in enumerate(batches):
            batch_file = os.path.join(tmpdir, f"output_{batch_number}.json")
            framework.save_results(batch_results, batch_file)
            batch_files.append(batch_file)
            del batch_results

        # Merge the per-batch files back into one list, in batch order
        merged_results = []
        for batch_file in batch_files:
            with open(batch_file, "r", encoding="utf-8") as handle:
                merged_results.extend(json.load(handle))

        # Make sure the destination directory exists before writing the combined output
        destination = Path(args.output)
        destination.parent.mkdir(parents=True, exist_ok=True)

        with open(args.output, "w", encoding="utf-8") as handle:
            json.dump(merged_results, handle, indent=4)
else:
    logging.info("Running tests in regular mode")
    results = framework.run_tests(
        silent=args.silent,
        adequacy=args.adequacy,
        bootstrap_size=args.bootstrap_size,
    )
    framework.save_results(results)

logging.info("Causal testing completed successfully.")


if __name__ == "__main__":
Expand Down
Empty file.
Loading