Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions chemgfn/cli/merge_csv_groups.py
Original file line number Diff line number Diff line change
Expand Up @@ -183,7 +183,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--root",
type=str,
default="/data1/xw3763/project/gflow/ChemGFN/logs/train",
default=None,
help="Root directory to search recursively.",
)
parser.add_argument(
Expand Down Expand Up @@ -272,8 +272,11 @@ def main() -> None:
if not roots:
print(f"[info] root_glob matched nothing: {args.root_glob}")
return
else:
elif args.root is not None:
roots = [Path(args.root).expanduser().resolve()]
else:
print("[error] provide --root or --root-glob")
raise SystemExit(1)

for root in roots:
print(f"[root] processing {root}")
Expand Down
7 changes: 5 additions & 2 deletions chemgfn/cli/unmerge_logs.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,7 +126,7 @@ def parse_args() -> argparse.Namespace:
parser.add_argument(
"--root",
type=str,
default="/data1/xw3763/project/gflow/ChemGFN/logs/train",
default=None,
help="Root directory to search recursively.",
)
parser.add_argument(
Expand Down Expand Up @@ -186,8 +186,11 @@ def main() -> None:
if not roots:
print(f"[info] root_glob matched nothing: {args.root_glob}")
return
else:
elif args.root is not None:
roots = [Path(args.root).expanduser().resolve()]
else:
print("[error] provide --root or --root-glob")
raise SystemExit(1)

for root in roots:
print(f"[root] expanding {root}")
Expand Down
9 changes: 6 additions & 3 deletions chemgfn/models/gfn.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
from __future__ import annotations

import json
import logging
import os
import random
import sys
Expand Down Expand Up @@ -52,6 +53,8 @@
from chemgfn.utils.replay_buffer import ReplayBuffer
from chemgfn.utils.schedulers import Scheduler

log = logging.getLogger(__name__)

# Re-export Scheduler for backward compatibility with config files
__all__ = ["ChemGFNModule", "Scheduler"]

Expand Down Expand Up @@ -309,7 +312,7 @@ def __init__(
self.net_frozen, mode="max-autotune", fullgraph=False
)
except Exception as exc: # pragma: no cover - defensive logging
print(f"torch.compile failed, continuing without compilation: {exc}")
log.warning("torch.compile failed, continuing without compilation: %s", exc)

# phi cache
self._pv_probe_cache = None
Expand Down Expand Up @@ -2254,7 +2257,7 @@ def _load_token_masks(self):
return prepare_token_mask(self.tokenizer, tokens_path)

if tokens_path:
print(f"Legal tokens file not found: {tokens_path}")
log.warning("Legal tokens file not found: %s", tokens_path)

# Fallback: support explicit illegal token strings (common for text tasks).
illegal_tokens = getattr(self.constraint_config, "illegal_tokens", None)
Expand Down Expand Up @@ -2283,7 +2286,7 @@ def _build_pre_grammar_processor(self, parsed_grammar):
return None
if processor_type == "general":
if self.grammar is None:
print("Grammar parsing failed with current tokenizer, disable general processor")
log.warning("Grammar parsing failed with current tokenizer, disabling general processor")
return None
return GrammarConstrainedLogitsProcessor(self.grammar)

Expand Down
5 changes: 4 additions & 1 deletion chemgfn/models/losses.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
including SubTrajectory Balance (SubTB) losses with various enhancements.
"""

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Literal, Optional, Tuple

import torch

log = logging.getLogger(__name__)
import torch.nn as nn
import torch.nn.functional as F

Expand Down Expand Up @@ -602,7 +605,7 @@ def __init__(
if not (self.gamma < 1.0):
raise ValueError(f"gamma must be < 1.0, got gamma={self.gamma}")
if self.k_min < 1:
print(f"k_min must be >= 1, got k_min={self.k_min}, setting to 1 automatically")
log.warning("k_min must be >= 1, got k_min=%d, setting to 1 automatically", self.k_min)
self.k_min = 1

self.logZ = torch.nn.Parameter(torch.tensor([float(init_logZ)], dtype=torch.float32))
Expand Down
1 change: 0 additions & 1 deletion chemgfn/models/reward.py
Original file line number Diff line number Diff line change
Expand Up @@ -139,7 +139,6 @@ def score_fast(

logits = logits.detach()[:, skip_first - 1 :]

# TODO: remove all penalty from the logits
if invalid_vocab_mask is not None:
logits = logits.clone()
logits[:, :, invalid_vocab_mask] += illegal_vocab_penalty
Expand Down
190 changes: 0 additions & 190 deletions chemgfn/utils/buffer_sample_example.py

This file was deleted.

1 change: 0 additions & 1 deletion chemgfn/utils/gfn_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -324,7 +324,6 @@ def generate_and_return_termination_logprob(
)
else:
token_ids = action_seq[:, idx].unsqueeze(-1).to(device)
# TODO: simple mask, no-eos before max_len;
scores = logits.clone().detach()
scores = default_processor(state, scores)
results = logits_processor(state, scores, disable_grammar=disable_grammar)
Expand Down
34 changes: 19 additions & 15 deletions scripts/run_eval_all.sh
Original file line number Diff line number Diff line change
Expand Up @@ -2,30 +2,34 @@
# run all evals distributed across selected CUDA devices; keep going if one fails
set -uo pipefail

readarray -t cmds <<'EOF'
# Root directory containing training logs. Override via environment variable:
# LOGS_ROOT=/your/logs/path bash scripts/run_eval_all.sh
LOGS_ROOT="${LOGS_ROOT:-./logs/train}"

readarray -t cmds <<EOF
# baseline
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_TB" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_TB/train/runs/2025-12-31_05-22-27/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_no_TB" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_TB_no_CFG/train/runs/2025-12-27_22-42-15/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_SubM/SMILES_cfg_TB_subM_replay_add_len_func" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_TB_subM_replay_add_len_func/train/runs/2026-01-06_13-17-11/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_subTB" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_subTB/train/runs/2026-01-03_14-02-50/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_SubM/SMILES_cfg_SubTB_subM_full" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_SubTB_subM_full/train/runs/2026-01-20_13-00-48/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_TB_wo_ref" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_TB_wo_ref/train/runs/2026-01-18_04-46-17/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_TB" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_TB/train/runs/2025-12-31_05-22-27/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_no_TB" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_TB_no_CFG/train/runs/2025-12-27_22-42-15/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_SubM/SMILES_cfg_TB_subM_replay_add_len_func" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_TB_subM_replay_add_len_func/train/runs/2026-01-06_13-17-11/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_subTB" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_subTB/train/runs/2026-01-03_14-02-50/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_SubM/SMILES_cfg_SubTB_subM_full" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_SubTB_subM_full/train/runs/2026-01-20_13-00-48/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_basic/SMILES_cfg_TB_wo_ref" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_TB_wo_ref/train/runs/2026-01-18_04-46-17/checkpoints/last.ckpt" test_repeats=3

# RapTB
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_mix_fix" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_RapTB_v2_kmin_5_to_2_mix_fix_softmax_overflow/train/runs/2026-01-10_06-13-37/checkpoints/last_2.47.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_mix_fix_subM" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_RapTB_v2_kmin_5_to_2_mix_fix_softmax_overflow_subM/train/runs/2026-01-10_06-26-26/checkpoints/epoch_009_diversity_2.6636_best.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_mix_fix" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_RapTB_v2_kmin_5_to_2_mix_fix_softmax_overflow/train/runs/2026-01-10_06-13-37/checkpoints/last_2.47.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_mix_fix_subM" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_RapTB_v2_kmin_5_to_2_mix_fix_softmax_overflow_subM/train/runs/2026-01-10_06-26-26/checkpoints/epoch_009_diversity_2.6636_best.ckpt" test_repeats=3

# RapTB ablation
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_max_only" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_RapTB_v2_kmin_5_to_2_max_only/train/runs/2026-01-19_05-30-27/checkpoints/epoch_019_diversity_2.3899.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_soft_only" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_RapTB_v2_kmin_5_to_2_soft_only/train/runs/2026-01-19_05-30-58/checkpoints/epoch_019_diversity_2.0664.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_max_only" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_RapTB_v2_kmin_5_to_2_max_only/train/runs/2026-01-19_05-30-27/checkpoints/epoch_019_diversity_2.3899.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_RapTB/SMILES_cfg_RapTB_v2_kmin_5_to_2_soft_only" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_RapTB_v2_kmin_5_to_2_soft_only/train/runs/2026-01-19_05-30-58/checkpoints/epoch_019_diversity_2.0664.ckpt" test_repeats=3

# length 15
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_TB_len_15" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_TB_len_15/train/runs/2026-01-07_03-26-51/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_subTB_len_15" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_CFG_subTB_len_15/train/runs/2026-01-12_22-46-25/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_TB_len_15" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_TB_len_15/train/runs/2026-01-07_03-26-51/checkpoints/last.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_subTB_len_15" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_CFG_subTB_len_15/train/runs/2026-01-12_22-46-25/checkpoints/last.ckpt" test_repeats=3

# Length 15 RapTB
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_RapTB_v2_kmin_12_to_8_mix_fix_len15" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_RapTB_v2_kmin_12_to_8_mix_fix_len15/train/runs/2026-01-12_00-51-08/checkpoints/last_1.91.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_RapTB_v2_kmin_12_to_8_mix_fix_len15_subM" +trainer.limit_test_batches=100 ckpt_path="/data1/xw3763/project/gflow/ChemGFN/logs/train/smiles_RapTB_v2_kmin_12_to_8_mix_fix_len15_subM/train/runs/2026-01-13_08-57-58/checkpoints/2.25.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_RapTB_v2_kmin_12_to_8_mix_fix_len15" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_RapTB_v2_kmin_12_to_8_mix_fix_len15/train/runs/2026-01-12_00-51-08/checkpoints/last_1.91.ckpt" test_repeats=3
python chemgfn/eval.py experiment="SMILES_Length/SMILES_cfg_RapTB_v2_kmin_12_to_8_mix_fix_len15_subM" +trainer.limit_test_batches=100 ckpt_path="${LOGS_ROOT}/smiles_RapTB_v2_kmin_12_to_8_mix_fix_len15_subM/train/runs/2026-01-13_08-57-58/checkpoints/2.25.ckpt" test_repeats=3

EOF

Expand Down
Loading