From 85e5ecff51a814c37ad7d20c7a61702b62abe617 Mon Sep 17 00:00:00 2001 From: Claude Date: Thu, 12 Mar 2026 17:07:51 +0000 Subject: [PATCH] camera-ready cleanup: fix hardcoded paths, remove TODOs, replace prints with logging, auto-version - Replace hardcoded absolute paths in CLI tools (merge_csv_groups, unmerge_logs) with default=None and require --root or --root-glob explicitly - Replace hardcoded checkpoint base path in eval scripts with LOGS_ROOT env var (defaults to ./logs/train), supporting clean override without editing scripts - Remove two stale TODO comments (gfn_utils.py, reward.py) - Delete debug utility script buffer_sample_example.py - Replace bare print() calls in gfn.py and losses.py with logging.warning() - Update setup.py: version auto-generated as 0.1.0+ https://claude.ai/code/session_011zEp1N46irzpZqjM15vmTP --- chemgfn/cli/merge_csv_groups.py | 7 +- chemgfn/cli/unmerge_logs.py | 7 +- chemgfn/models/gfn.py | 9 +- chemgfn/models/losses.py | 5 +- chemgfn/models/reward.py | 1 - chemgfn/utils/buffer_sample_example.py | 190 ------------------------- chemgfn/utils/gfn_utils.py | 1 - scripts/run_eval_all.sh | 34 +++-- scripts/run_eval_expr24_all.sh | 40 +++--- setup.py | 27 +++- 10 files changed, 85 insertions(+), 236 deletions(-) delete mode 100644 chemgfn/utils/buffer_sample_example.py diff --git a/chemgfn/cli/merge_csv_groups.py b/chemgfn/cli/merge_csv_groups.py index 801b811..97df3f5 100644 --- a/chemgfn/cli/merge_csv_groups.py +++ b/chemgfn/cli/merge_csv_groups.py @@ -183,7 +183,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--root", type=str, - default="/data1/xw3763/project/gflow/ChemGFN/logs/train", + default=None, help="Root directory to search recursively.", ) parser.add_argument( @@ -272,8 +272,11 @@ def main() -> None: if not roots: print(f"[info] root_glob matched nothing: {args.root_glob}") return - else: + elif args.root is not None: roots = [Path(args.root).expanduser().resolve()] + else: + print("[error] provide --root or --root-glob") + raise SystemExit(1) for root in roots: print(f"[root] processing {root}") diff --git a/chemgfn/cli/unmerge_logs.py b/chemgfn/cli/unmerge_logs.py index 5e8c8f5..991497a 100644 --- a/chemgfn/cli/unmerge_logs.py +++ b/chemgfn/cli/unmerge_logs.py @@ -126,7 +126,7 @@ def parse_args() -> argparse.Namespace: parser.add_argument( "--root", type=str, - default="/data1/xw3763/project/gflow/ChemGFN/logs/train", + default=None, help="Root directory to search recursively.", ) parser.add_argument( @@ -186,8 +186,11 @@ def main() -> None: if not roots: print(f"[info] root_glob matched nothing: {args.root_glob}") return - else: + elif args.root is not None: roots = [Path(args.root).expanduser().resolve()] + else: + print("[error] provide --root or --root-glob") + raise SystemExit(1) for root in roots: print(f"[root] expanding {root}") diff --git a/chemgfn/models/gfn.py b/chemgfn/models/gfn.py index adc2d75..e326aba 100644 --- a/chemgfn/models/gfn.py +++ b/chemgfn/models/gfn.py @@ -3,6 +3,7 @@ from __future__ import annotations import json +import logging import os import random import sys @@ -52,6 +53,8 @@ from chemgfn.utils.replay_buffer import ReplayBuffer from chemgfn.utils.schedulers import Scheduler +log = logging.getLogger(__name__) + # Re-export Scheduler for backward compatibility with config files __all__ = ["ChemGFNModule", "Scheduler"] @@ -309,7 +312,7 @@ def __init__( self.net_frozen, mode="max-autotune", fullgraph=False ) except Exception as exc: # pragma: no cover - defensive logging - print(f"torch.compile failed, continuing without compilation: {exc}") + log.warning("torch.compile failed, continuing without compilation: %s", exc) # phi cache self._pv_probe_cache = None @@ -2254,7 +2257,7 @@ def _load_token_masks(self): return prepare_token_mask(self.tokenizer, tokens_path) if tokens_path: - print(f"Legal tokens file not found: {tokens_path}") + log.warning("Legal tokens file not found: %s", tokens_path) # Fallback: support explicit illegal token strings (common for text tasks). illegal_tokens = getattr(self.constraint_config, "illegal_tokens", None) @@ -2283,7 +2286,7 @@ def _build_pre_grammar_processor(self, parsed_grammar): return None if processor_type == "general": if self.grammar is None: - print("Grammar parsing failed with current tokenizer, disable general processor") + log.warning("Grammar parsing failed with current tokenizer, disabling general processor") return None return GrammarConstrainedLogitsProcessor(self.grammar) diff --git a/chemgfn/models/losses.py b/chemgfn/models/losses.py index ba7b7ab..d890295 100644 --- a/chemgfn/models/losses.py +++ b/chemgfn/models/losses.py @@ -5,10 +5,13 @@ including SubTrajectory Balance (SubTB) losses with various enhancements. """ +import logging from abc import ABC, abstractmethod from typing import Any, Dict, Literal, Optional, Tuple import torch + +log = logging.getLogger(__name__) import torch.nn as nn import torch.nn.functional as F @@ -602,7 +605,7 @@ def __init__( if not (self.gamma < 1.0): raise ValueError(f"gamma must be < 1.0, got gamma={self.gamma}") if self.k_min < 1: - print(f"k_min must be >= 1, got k_min={self.k_min}, setting to 1 automatically") + log.warning("k_min must be >= 1, got k_min=%d, setting to 1 automatically", self.k_min) self.k_min = 1 self.logZ = torch.nn.Parameter(torch.tensor([float(init_logZ)], dtype=torch.float32)) diff --git a/chemgfn/models/reward.py b/chemgfn/models/reward.py index b6dc16f..7fb6225 100644 --- a/chemgfn/models/reward.py +++ b/chemgfn/models/reward.py @@ -139,7 +139,6 @@ def score_fast( logits = logits.detach()[:, skip_first - 1 :] - # TODO: remove all penalty from the logits if invalid_vocab_mask is not None: logits = logits.clone() logits[:, :, invalid_vocab_mask] += illegal_vocab_penalty diff --git a/chemgfn/utils/buffer_sample_example.py b/chemgfn/utils/buffer_sample_example.py deleted file mode 100644 index 1c88ef0..0000000 --- a/chemgfn/utils/buffer_sample_example.py +++ /dev/null @@ -1,190 +0,0 @@ -#!/usr/bin/env python -""" -Buffer Sampling Usage Example - -This script demonstrates how to create and use buffer samples. -""" - -import torch -from transformers import AutoTokenizer - - -def create_buffer_samples( - tokenizer_name: str = "meta-llama/Meta-Llama-3-8B-Instruct", - num_samples: int = 100, - seq_len: int = 10, - output_path: str = "buffer_samples.pt", -): - """ - Create example buffer samples file. - - Args: - tokenizer_name: Tokenizer name - num_samples: Number of samples to generate - seq_len: Length of each sample - output_path: Output file path - """ - print(f"Creating buffer samples...") - print(f" Tokenizer: {tokenizer_name}") - print(f" Number of samples: {num_samples}") - print(f" Sequence length: {seq_len}") - - # Load tokenizer - tokenizer = AutoTokenizer.from_pretrained(tokenizer_name) - vocab_size = len(tokenizer) - - # Generate random token IDs (in practice, should be meaningful sequences) - buffer_samples = torch.randint(0, vocab_size, (num_samples, seq_len), dtype=torch.long) - - # Ensure each sequence ends with EOS token - buffer_samples[:, -1] = tokenizer.eos_token_id - - # Save - torch.save(buffer_samples, output_path) - - print(f"\nOK: Buffer samples saved to: {output_path}") - print(f" Shape: {buffer_samples.shape}") - print(f" Dtype: {buffer_samples.dtype}") - print(f"\nExample sample:") - print(f" Token IDs: {buffer_samples[0].tolist()}") - print(f" Decoded: {tokenizer.decode(buffer_samples[0], skip_special_tokens=False)}") - - -def verify_buffer_samples(buffer_path: str): - """ - Verify that the buffer samples file format is correct. - - Args: - buffer_path: Buffer file path - """ - print(f"\nVerifying buffer samples: {buffer_path}") - - try: - # Try to load - buffer = torch.load(buffer_path) - - # Check type - if isinstance(buffer, torch.Tensor): - print(f"OK: Type: Tensor") - print(f"OK: Shape: {buffer.shape}") - print(f"OK: Dtype: {buffer.dtype}") - print(f"OK: Number of elements: {buffer.numel()}") - - if buffer.numel() == 0: - print("WARN: Buffer is empty!") - return False - - elif isinstance(buffer, list): - print(f"OK: Type: List") - print(f"OK: Length: {len(buffer)}") - - if len(buffer) == 0: - print("WARN: Buffer is empty!") - return False - - print(f"OK: First element type: {type(buffer[0])}") - if isinstance(buffer[0], torch.Tensor): - print(f"OK: First element shape: {buffer[0].shape}") - else: - print(f"ERROR: Unknown type: {type(buffer)}") - return False - - print(f"\nOK: Buffer samples verification passed!") - return True - - except FileNotFoundError: - print(f"ERROR: File not found: {buffer_path}") - return False - except Exception as e: - print(f"ERROR: Loading failed: {e}") - return False - - -def create_empty_buffer(output_path: str = "empty_buffer.pt"): - """ - Create an empty buffer file for testing automatic detection. - - Args: - output_path: Output file path - """ - print(f"\nCreating empty buffer: {output_path}") - - # Create empty tensor - empty_buffer = torch.tensor([], dtype=torch.long) - - # Save - torch.save(empty_buffer, output_path) - - print(f"OK: Empty buffer created: {output_path}") - print( - f" This file will be automatically detected as invalid, buffer sampling will be disabled" - ) - - -if __name__ == "__main__": - import argparse - - parser = argparse.ArgumentParser(description="Buffer Sampling Tool") - parser.add_argument( - "--action", - type=str, - choices=["create", "verify", "create-empty"], - default="create", - help="Action to perform", - ) - parser.add_argument( - "--output", - type=str, - default="buffer_samples.pt", - help="Output file path", - ) - parser.add_argument( - "--num-samples", - type=int, - default=100, - help="Number of samples", - ) - parser.add_argument( - "--seq-len", - type=int, - default=10, - help="Sequence length", - ) - parser.add_argument( - "--tokenizer", - type=str, - default="meta-llama/Meta-Llama-3-8B-Instruct", - help="Tokenizer name", - ) - - args = parser.parse_args() - - if args.action == "create": - create_buffer_samples( - tokenizer_name=args.tokenizer, - num_samples=args.num_samples, - seq_len=args.seq_len, - output_path=args.output, - ) - elif args.action == "verify": - verify_buffer_samples(args.output) - elif args.action == "create-empty": - create_empty_buffer(args.output) - - print("\n" + "=" * 60) - print("Usage examples:") - print("=" * 60) - print("\n1. Create buffer samples:") - print( - " python buffer_sample_example.py --action create --output my_buffer.pt --num-samples 1000" - ) - print("\n2. Verify buffer samples:") - print(" python buffer_sample_example.py --action verify --output my_buffer.pt") - print("\n3. Create empty buffer (test automatic detection):") - print(" python buffer_sample_example.py --action create-empty --output empty.pt") - print("\n4. Use in training:") - print( - " python chemgfn/train.py experiment=SMILES_basic/SMILES_cfg_TB " - "data.buffer_sample_path=my_buffer.pt" - ) - print() diff --git a/chemgfn/utils/gfn_utils.py b/chemgfn/utils/gfn_utils.py index 05de5b7..5259f80 100644 --- a/chemgfn/utils/gfn_utils.py +++ b/chemgfn/utils/gfn_utils.py @@ -324,7 +324,6 @@ def generate_and_return_termination_logprob( ) else: token_ids = action_seq[:, idx].unsqueeze(-1).to(device) - # TODO: simple mask, no-eos before max_len; scores = logits.clone().detach() scores = default_processor(state, scores) results = logits_processor(state, scores, disable_grammar=disable_grammar) diff --git a/scripts/run_eval_all.sh b/scripts/run_eval_all.sh index ef2099b..fd0c7be 100644 --- a/scripts/run_eval_all.sh +++ b/scripts/run_eval_all.sh @@ -2,30 +2,34 @@ # run all evals distributed across selected CUDA devices; keep going if one fails set -uo pipefail -readarray -t cmds <<'EOF' +# Root directory containing training logs. Override via environment variable: +# LOGS_ROOT=/your/logs/path bash scripts/run_eval_all.sh +LOGS_ROOT="${LOGS_ROOT:-./logs/train}" + +readarray -t cmds < str: + base = "0.1.0" + try: + commit = ( + subprocess.check_output( + ["git", "rev-parse", "--short", "HEAD"], + cwd=os.path.dirname(os.path.abspath(__file__)), + stderr=subprocess.DEVNULL, + ) + .decode() + .strip() + ) + return f"{base}+{commit}" + except Exception: + return base + + setup( name="chemgfn", - version="0.0.1", - description="Gflownet for chemistry", + version=get_version(), + description="GFlowNet with LLMs for chemistry and arithmetic generation", author="", author_email="", - url="https://github.com/user/project", + url="https://github.com/ComDec/ChemGFN", install_requires=[ "torch", "torchvision",