#!/usr/bin/env python
# -*- coding: utf-8 -*-

r"""
SHAP Data Generator (CNNModel1)
===============================
This module calculates SHAP values and saves them to disk in .npz (Numpy Zip) format.
It does not perform visualization; it only generates raw data.

Output File Content (keys):
- 'shap_values': SHAP values in (C, H, W) format.
- 'image': Original normalized image in (C, H, W) format.
- 'prediction': int (Predicted class index)
- 'probabilities': array (Class probabilities)

EXAMPLE USAGE:
--------------

1. Prepare and Run (Complete Workflow):
CUDA_VISIBLE_DEVICES=0 python shap_generator_cnn.py prepare-and-run \
    --model_path "/path/to/your/model.pth" \
    --fc1 128 --fc2 256 --dropout 0.2 \
    --num_background 50 \
    --explainer_path "my_explainer.pkl" \
    --input_path "/path/to/images_folder" \
    --output_dir "./output_results" \
    --cuda_selection 0

2. Prepare Only (Generates the background dataset and explainer object):
python shap_generator_cnn.py prepare \
    --model_path "/path/to/model.pth" \
    --explainer_path "my_explainer.pkl"

3. Run Only (Uses an existing .pkl explainer to process images):
python shap_generator_cnn.py run \
    --explainer_path "my_explainer.pkl" \
    --input_path "/path/to/single_image.png" \
    --output_dir "./output_results"
"""

import argparse
import os
import pickle
import time
import warnings
from glob import glob

import cv2
import numpy as np
import shap
import torch

from models import CNNModel1

warnings.filterwarnings('ignore')


def _format_duration(seconds):
    """Format seconds as Hh Mm Ss."""
    seconds = int(seconds)
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    if h > 0:
        return f"{h}h {m}m {s}s"
    if m > 0:
        return f"{m}m {s}s"
    return f"{s}s"


# -----------------------------------------------------------------------------
# Data Saving Function
# -----------------------------------------------------------------------------

def save_shap_data(image_tensor, shap_values, prediction, probabilities, output_path):
    """
    Save one explanation in compressed numpy format (.npz).

    Args:
        image_tensor: Input image. A torch tensor is moved to the CPU and its
            leading batch dimension is dropped; anything else is stored as is.
        shap_values: SHAP attributions; a torch tensor is converted to numpy.
        prediction: Predicted class index.
        probabilities: Class probabilities for the image.
        output_path: Destination path handed to ``np.savez_compressed``.
    """
    if isinstance(image_tensor, torch.Tensor):
        # Remove batch dimension: (1, 3, 300, 300) -> (3, 300, 300)
        img_np = image_tensor.detach().cpu().numpy()[0]
    else:
        img_np = image_tensor

    if isinstance(shap_values, torch.Tensor):
        shap_values = shap_values.detach().cpu().numpy()

    np.savez_compressed(
        output_path,
        shap_values=shap_values,
        image=img_np,
        prediction=prediction,
        probabilities=probabilities,
    )


# -----------------------------------------------------------------------------
# Core Functions
# -----------------------------------------------------------------------------

class ModelWrapper(torch.nn.Module):
    """Wrap a classifier so that its forward pass returns softmax probabilities."""

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        output = self.model(x)
        return torch.nn.functional.softmax(output, dim=1)


def create_background_dataset(model, device, num_samples=50, img_size=300):
    """
    Build the SHAP background batch: ``num_samples`` all-white images.

    Args:
        model: Unused; kept so existing callers keep working.
        device: Torch device the batch is moved to.
        num_samples: Number of background images.
        img_size: Height and width of each image in pixels.

    Returns:
        Float tensor of shape (num_samples, 3, img_size, img_size) filled
        with 1.0 (white after the /255 normalization used by ``load_image``).
    """
    print(f"Creating background dataset with {num_samples} samples...", flush=True)
    # Every background image is identical, so build the batch in one call.
    background = np.ones((num_samples, 3, img_size, img_size), dtype=np.float32)
    return torch.FloatTensor(background).to(device)


def prepare_explainer(model_path, fc1, fc2, dropout, num_background, cuda_selection, save_path):
    """
    Load the CNN, build a SHAP GradientExplainer and pickle it to ``save_path``.

    Args:
        model_path: Path of the state dict loaded into ``CNNModel1``.
        fc1, fc2, dropout: Constructor arguments forwarded to ``CNNModel1``.
        num_background: Number of background images for SHAP.
        cuda_selection: GPU index used when CUDA is available.
        save_path: Where the pickled explainer bundle is written.

    Returns:
        ``(explainer, wrapped_model, device)``, or ``(None, None, None)`` when
        the model cannot be constructed or loaded.
    """
    print("=" * 60)
    print("Mode: Preparing SHAP Generator")
    print("=" * 60, flush=True)
    device = f'cuda:{cuda_selection}' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}", flush=True)

    try:
        model = CNNModel1(fc1, fc2, dropout).to(device)
        model.load_state_dict(torch.load(model_path, map_location=device))
        model.eval()
    except Exception as e:
        # Top-level boundary of the CLI: report and let the caller decide.
        print(f"Error loading model: {e}")
        return None, None, None

    wrapped_model = ModelWrapper(model).to(device)
    background_images = create_background_dataset(model, device, num_samples=num_background)

    explainer = shap.GradientExplainer(wrapped_model, background_images)

    explainer_data = {
        'explainer': explainer, 'model_path': model_path, 'fc1': fc1,
        'fc2': fc2, 'dropout': dropout, 'device': device, 'wrapped_model': wrapped_model,
    }

    with open(save_path, 'wb') as f:
        pickle.dump(explainer_data, f)
    return explainer, wrapped_model, device


def load_explainer(explainer_path):
    """
    Restore an explainer bundle written by ``prepare_explainer``.

    Returns:
        ``(explainer, wrapped_model, device)``, or ``(None, None, None)`` when
        ``explainer_path`` does not exist.
    """
    if not os.path.exists(explainer_path):
        print(f"Error: explainer file not found: {explainer_path}", flush=True)
        return None, None, None
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only load explainer files that you created yourself.
    with open(explainer_path, 'rb') as f:
        data = pickle.load(f)
    data['wrapped_model'].to(data['device'])
    data['wrapped_model'].eval()
    return data['explainer'], data['wrapped_model'], data['device']


def load_image(image_path, img_size=300):
    """
    Read an image and return it as a (1, 3, img_size, img_size) float tensor.

    The image is read with ``cv2.imread``, resized when its height/width differ
    from ``img_size``, scaled to [0, 1] and transposed from HWC to CHW. No
    colour-channel reordering is applied to what ``cv2.imread`` returns.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Error: Could not load {image_path}")
    if img.shape[:2] != (img_size, img_size):
        img = cv2.resize(img, (img_size, img_size))

    # Normalization and Channel Transpose
    img_normalized = img.astype(np.float32) / 255.0
    img_normalized = img_normalized.transpose((2, 0, 1))
    return torch.FloatTensor(img_normalized).unsqueeze(0)


def _collect_image_paths(input_path):
    """Return the image files designated by ``input_path`` in sorted order."""
    if os.path.isfile(input_path):
        return [input_path]
    if os.path.isdir(input_path):
        found = []
        for pattern in ('*.png', '*.jpg', '*.jpeg'):
            found.extend(glob(os.path.join(input_path, pattern)))
        # glob order depends on the filesystem; sort for reproducible runs.
        return sorted(found)
    return []


def _apply_batch_size(explainer, batch_size):
    """
    Forward ``batch_size`` to an already constructed GradientExplainer.

    NOTE(review): this relies on shap keeping its framework-specific
    implementation on ``explainer.explainer`` with a ``batch_size`` attribute;
    confirm against the installed shap version. When either attribute is
    missing nothing is changed and ``False`` is returned.
    """
    inner = getattr(explainer, 'explainer', None)
    if inner is None or not hasattr(inner, 'batch_size'):
        return False
    inner.batch_size = batch_size
    return True


def process_input_path(explainer, wrapped_model, device, input_path, output_dir,
                       nsamples=200, batch_size=200):
    """
    Compute SHAP values for one image or every image in a directory.

    For each image a ``<image name>.npz`` file is written to ``output_dir``
    (see ``save_shap_data``). A failure on one image is logged and counted;
    processing continues with the next image.

    Args:
        explainer: SHAP explainer, normally a ``shap.GradientExplainer``.
        wrapped_model: Model returning class probabilities.
        device: Torch device the images are moved to.
        input_path: Image file, or directory searched for png/jpg/jpeg files.
        output_dir: Directory for the .npz results; created when missing.
        nsamples: Number of SHAP samples per image.
        batch_size: SHAP internal batch size (larger uses more GPU memory).
    """
    os.makedirs(output_dir, exist_ok=True)

    image_paths = _collect_image_paths(input_path)
    total = len(image_paths)
    print(f"Found {total} images. Starting generation...", flush=True)
    print(f"SHAP params: nsamples={nsamples}, batch_size={batch_size}", flush=True)

    # batch_size used to be logged only; hand it to SHAP so it takes effect.
    if isinstance(explainer, shap.GradientExplainer) and not _apply_batch_size(explainer, batch_size):
        print("Warning: could not apply batch_size to the explainer; "
              "its built-in batch size is used.", flush=True)

    start_time = time.time()
    success = 0
    failed = 0

    for idx, image_path in enumerate(image_paths, start=1):
        iter_start = time.time()
        try:
            base_name = os.path.basename(image_path)
            file_name = os.path.splitext(base_name)[0]
            output_path = os.path.join(output_dir, f"{file_name}.npz")

            # --- Load Image ---
            image_tensor = load_image(image_path).to(device)

            # --- Predict ---
            with torch.no_grad():
                output = wrapped_model(image_tensor)
                prediction = torch.argmax(output, dim=1).cpu().item()
                probabilities = output[0].cpu().numpy()

            # --- SHAP ---
            image_tensor.requires_grad = True
            if not isinstance(explainer, shap.GradientExplainer):
                explainer = shap.GradientExplainer(wrapped_model, explainer.data,
                                                   batch_size=batch_size)

            shap_values = explainer.shap_values(image_tensor, nsamples=nsamples)

            # Shape correction: keep only the predicted class and drop the
            # batch dimension, whichever layout SHAP returned.
            if isinstance(shap_values, list):
                shap_values = shap_values[prediction]
            if len(shap_values.shape) == 5:
                shap_values = shap_values[0, :, :, :, prediction]
            elif len(shap_values.shape) == 4 and shap_values.shape[0] == 1:
                shap_values = shap_values[0]

            # --- Save Data ---
            save_shap_data(image_tensor, shap_values, prediction, probabilities, output_path)
            success += 1

            # --- Progress Log ---
            iter_time = time.time() - iter_start
            elapsed = time.time() - start_time
            avg = elapsed / idx
            remaining = avg * (total - idx)
            pct = 100.0 * idx / total
            print(
                f"[{idx}/{total}] ({pct:5.1f}%) {base_name} "
                f"| iter: {iter_time:.2f}s | avg: {avg:.2f}s "
                f"| elapsed: {_format_duration(elapsed)} "
                f"| eta: {_format_duration(remaining)}",
                flush=True,
            )

        except Exception as e:
            # Best effort by design: one bad image must not stop the batch.
            failed += 1
            print(f"[{idx}/{total}] FAILED {image_path}: {e}", flush=True)

    total_time = time.time() - start_time
    print("=" * 60, flush=True)
    print(
        f"Done. success={success}, failed={failed}, total={total} "
        f"| total_time={_format_duration(total_time)} "
        f"| avg={total_time / max(total, 1):.2f}s/img",
        flush=True,
    )
    print("=" * 60, flush=True)


# -----------------------------------------------------------------------------
# Main Execution
# -----------------------------------------------------------------------------

def main():
    """Parse the command line and run ``prepare``, ``run`` or ``prepare-and-run``."""
    parser = argparse.ArgumentParser(description='SHAP Data Generator')
    subparsers = parser.add_subparsers(dest='command', required=True)

    # Common parameters for model setup
    model_p = argparse.ArgumentParser(add_help=False)
    model_p.add_argument('--model_path', type=str, required=True, help='Path to the .pth model file')
    model_p.add_argument('--fc1', type=int, default=512)
    model_p.add_argument('--fc2', type=int, default=256)
    model_p.add_argument('--dropout', type=float, default=0.2)
    model_p.add_argument('--cuda_selection', type=int, default=0, help='GPU ID to use')
    model_p.add_argument('--num_background', type=int, default=50, help='Number of samples for SHAP background')
    model_p.add_argument('--explainer_path', type=str, default='shap_explainer_gen.pkl', help='Path to save/load explainer')

    # Common parameters for running analysis
    run_p = argparse.ArgumentParser(add_help=False)
    run_p.add_argument('--input_path', type=str, required=True, help='Path to image or directory of images')
    run_p.add_argument('--output_dir', type=str, required=True, help='Directory to save .npz results')

    # Dummy parameters to maintain compatibility with visualization scripts
    run_p.add_argument('--percentile', type=int, default=95, help='Ignored in generator mode')
    run_p.add_argument('--alpha', type=float, default=0.5, help='Ignored in generator mode')

    # SHAP sampling params (quality-preserving speedup: keep nsamples=200, raise batch_size)
    run_p.add_argument('--nsamples', type=int, default=200,
                       help='SHAP nsamples (default 200, matches SHAP default — do not lower to preserve quality)')
    run_p.add_argument('--batch_size', type=int, default=200,
                       help='SHAP internal batch size for GPU utilization (higher = faster, uses more GPU memory)')

    # Commands: prepare, run, prepare-and-run
    subparsers.add_parser('prepare', parents=[model_p], help='Initialize and save the SHAP explainer')

    run_parser = subparsers.add_parser('run', parents=[run_p], help='Run generation using existing explainer')
    run_parser.add_argument('--explainer_path', type=str, required=True)

    subparsers.add_parser('prepare-and-run', parents=[model_p, run_p], help='Initialize explainer and run generation')

    args = parser.parse_args()

    if args.command == 'run':
        explainer, wrapped_model, device = load_explainer(args.explainer_path)
    else:
        explainer, wrapped_model, device = prepare_explainer(
            args.model_path, args.fc1, args.fc2, args.dropout,
            args.num_background, args.cuda_selection, args.explainer_path)

    if explainer is None:
        # The cause was already printed; signal the failure to the shell.
        raise SystemExit(1)

    if args.command in ('run', 'prepare-and-run'):
        process_input_path(explainer, wrapped_model, device, args.input_path, args.output_dir,
                           nsamples=args.nsamples, batch_size=args.batch_size)


if __name__ == '__main__':
    main()
r"""
SHAP Data Generator (ViT / Swinv2 variant)
==========================================
Same purpose as shap_generator_cnn.py but targets the ViT model
(Swinv2ForImageClassification) defined in models.py instead of CNNModel1.

Output File Content (keys), identical to the CNN/YOLO variants:
- 'shap_values': SHAP values in (C, H, W) format.
- 'image': Original normalized image in (C, H, W) format.
- 'prediction': int (Predicted class index)
- 'probabilities': array (Class probabilities)

EXAMPLE USAGE:
--------------

1. Prepare and Run (Complete Workflow):
CUDA_VISIBLE_DEVICES=0 python shap_generator_vit.py prepare-and-run \
    --model_path "/path/to/your/vit_checkpoint.pth" \
    --num_classes 2 \
    --window_size 7 --hidden_size 768 --embed_dim 96 \
    --depths 2 2 6 2 --mlp_ratio 4 --encoder_stride 32 \
    --att_drop 0 --drop_path_rate 0.1 --dropout 0.5 --layer_norm_eps 1e-5 \
    --num_background 50 \
    --explainer_path "vit_explainer.pkl" \
    --input_path "/path/to/images_folder" \
    --output_dir "./output_results" \
    --cuda_selection 0

2. Prepare Only:
python shap_generator_vit.py prepare \
    --model_path "/path/to/vit_checkpoint.pth" \
    --explainer_path "vit_explainer.pkl"

3. Run Only:
python shap_generator_vit.py run \
    --explainer_path "vit_explainer.pkl" \
    --input_path "/path/to/single_image.png" \
    --output_dir "./output_results"
"""

import argparse
import os
import pickle
import time
import warnings
from glob import glob

import cv2
import numpy as np
import shap
import torch

from models import ViT

warnings.filterwarnings('ignore')


def _format_duration(seconds):
    """Format seconds as Hh Mm Ss."""
    seconds = int(seconds)
    h, rem = divmod(seconds, 3600)
    m, s = divmod(rem, 60)
    if h > 0:
        return f"{h}h {m}m {s}s"
    if m > 0:
        return f"{m}m {s}s"
    return f"{s}s"


# -----------------------------------------------------------------------------
# Data Saving Function
# -----------------------------------------------------------------------------

def save_shap_data(image_tensor, shap_values, prediction, probabilities, output_path):
    """
    Save one explanation in compressed numpy format (.npz).

    Args:
        image_tensor: Input image. A torch tensor is moved to the CPU and its
            leading batch dimension is dropped; anything else is stored as is.
        shap_values: SHAP attributions; a torch tensor is converted to numpy.
        prediction: Predicted class index.
        probabilities: Class probabilities for the image.
        output_path: Destination path handed to ``np.savez_compressed``.
    """
    if isinstance(image_tensor, torch.Tensor):
        # Drop the leading batch dimension of the (1, C, H, W) input.
        img_np = image_tensor.detach().cpu().numpy()[0]
    else:
        img_np = image_tensor

    if isinstance(shap_values, torch.Tensor):
        shap_values = shap_values.detach().cpu().numpy()

    np.savez_compressed(
        output_path,
        shap_values=shap_values,
        image=img_np,
        prediction=prediction,
        probabilities=probabilities,
    )


# -----------------------------------------------------------------------------
# Core Functions
# -----------------------------------------------------------------------------

class ModelWrapper(torch.nn.Module):
    """
    Wraps the classifier so that SHAP receives softmax probabilities
    and a clean single-tensor forward signature.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model

    def forward(self, x):
        output = self.model(x)
        # ViT.forward returns logits (a tensor) when return_attention=False,
        # but guard against tuple/list outputs in case of upstream changes.
        if isinstance(output, (tuple, list)):
            output = output[0]
        return torch.nn.functional.softmax(output, dim=1)


def create_background_dataset(model, device, num_samples=50, img_size=300):
    """
    Build the SHAP background batch: ``num_samples`` all-white images.

    Args:
        model: Unused; kept so existing callers keep working.
        device: Torch device the batch is moved to.
        num_samples: Number of background images.
        img_size: Height and width of each image in pixels.

    Returns:
        Float tensor of shape (num_samples, 3, img_size, img_size) filled
        with 1.0 (white after the /255 normalization used by ``load_image``).
    """
    print(f"Creating background dataset with {num_samples} samples...", flush=True)
    # Every background image is identical, so build the batch in one call.
    background = np.ones((num_samples, 3, img_size, img_size), dtype=np.float32)
    return torch.FloatTensor(background).to(device)


def _load_vit_state_dict(model, checkpoint):
    """
    Train pipeline saves checkpoints as {'model_state_dict': ..., 'optimizer_state_dict': ..., ...}
    (see train_deepscreen.py). Plain state_dicts are also supported for flexibility.
    """
    if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint:
        state_dict = checkpoint['model_state_dict']
    else:
        state_dict = checkpoint
    model.load_state_dict(state_dict)


def prepare_explainer(model_path, num_classes, vit_cfg, num_background,
                      cuda_selection, save_path, img_size=300):
    """
    Load the ViT, build a SHAP GradientExplainer and pickle it to ``save_path``.

    Args:
        model_path: Checkpoint path (wrapped or plain state dict).
        num_classes: Number of output classes passed to ``ViT``.
        vit_cfg: Dict of architecture arguments, see ``_build_vit_cfg``.
        num_background: Number of background images for SHAP.
        cuda_selection: GPU index used when CUDA is available.
        save_path: Where the pickled explainer bundle is written.
        img_size: Side length of the background images; stored in the bundle
            so that ``run`` resizes inputs the same way.

    Returns:
        ``(explainer, wrapped_model, device)``, or ``(None, None, None)`` when
        the model cannot be constructed or loaded.
    """
    print("=" * 60)
    print("Mode: Preparing SHAP Generator (ViT / Swinv2)")
    print("=" * 60, flush=True)
    device = f'cuda:{cuda_selection}' if torch.cuda.is_available() else 'cpu'
    print(f"Device: {device}", flush=True)

    try:
        model = ViT(
            window_size=vit_cfg['window_size'],
            hidden_size=vit_cfg['hidden_size'],
            att_drop=vit_cfg['att_drop'],
            drop_path_rate=vit_cfg['drop_path_rate'],
            drop_rate=vit_cfg['drop_rate'],
            layer_norm_eps=vit_cfg['layer_norm_eps'],
            encoder_stride=vit_cfg['encoder_stride'],
            embed_dim=vit_cfg['embed_dim'],
            depths=vit_cfg['depths'],
            mlp_ratio=vit_cfg['mlp_ratio'],
            num_classes=num_classes,
        ).to(device)
        checkpoint = torch.load(model_path, map_location=device)
        _load_vit_state_dict(model, checkpoint)
        model.eval()
    except Exception as e:
        # Top-level boundary of the CLI: report and let the caller decide.
        print(f"Error loading model: {e}")
        return None, None, None

    wrapped_model = ModelWrapper(model).to(device)
    background_images = create_background_dataset(model, device,
                                                  num_samples=num_background,
                                                  img_size=img_size)

    explainer = shap.GradientExplainer(wrapped_model, background_images)

    explainer_data = {
        'explainer': explainer,
        'model_path': model_path,
        'num_classes': num_classes,
        'vit_cfg': vit_cfg,
        'img_size': img_size,
        'device': device,
        'wrapped_model': wrapped_model,
    }

    with open(save_path, 'wb') as f:
        pickle.dump(explainer_data, f)
    return explainer, wrapped_model, device


def load_explainer(explainer_path):
    """
    Restore an explainer bundle written by ``prepare_explainer``.

    Returns:
        ``(explainer, wrapped_model, device, img_size)``; ``img_size`` falls
        back to 300 when the bundle has no such key. All four values are
        ``None`` when ``explainer_path`` does not exist.
    """
    if not os.path.exists(explainer_path):
        print(f"Error: explainer file not found: {explainer_path}", flush=True)
        return None, None, None, None
    # SECURITY: pickle.load executes arbitrary code embedded in the file.
    # Only load explainer files that you created yourself.
    with open(explainer_path, 'rb') as f:
        data = pickle.load(f)
    data['wrapped_model'].to(data['device'])
    data['wrapped_model'].eval()
    return data['explainer'], data['wrapped_model'], data['device'], data.get('img_size', 300)


def load_image(image_path, img_size=300):
    """
    Read an image and return it as a (1, 3, img_size, img_size) float tensor.

    The image is read with ``cv2.imread``, resized when its height/width differ
    from ``img_size``, scaled to [0, 1] and transposed from HWC to CHW. No
    colour-channel reordering is applied to what ``cv2.imread`` returns.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``image_path``.
    """
    img = cv2.imread(image_path)
    if img is None:
        raise FileNotFoundError(f"Error: Could not load {image_path}")
    if img.shape[:2] != (img_size, img_size):
        img = cv2.resize(img, (img_size, img_size))

    img_normalized = img.astype(np.float32) / 255.0
    img_normalized = img_normalized.transpose((2, 0, 1))
    return torch.FloatTensor(img_normalized).unsqueeze(0)


def _collect_image_paths(input_path):
    """Return the image files designated by ``input_path`` in sorted order."""
    if os.path.isfile(input_path):
        return [input_path]
    if os.path.isdir(input_path):
        found = []
        for pattern in ('*.png', '*.jpg', '*.jpeg'):
            found.extend(glob(os.path.join(input_path, pattern)))
        # glob order depends on the filesystem; sort for reproducible runs.
        return sorted(found)
    return []


def _apply_batch_size(explainer, batch_size):
    """
    Forward ``batch_size`` to an already constructed GradientExplainer.

    NOTE(review): this relies on shap keeping its framework-specific
    implementation on ``explainer.explainer`` with a ``batch_size`` attribute;
    confirm against the installed shap version. When either attribute is
    missing nothing is changed and ``False`` is returned.
    """
    inner = getattr(explainer, 'explainer', None)
    if inner is None or not hasattr(inner, 'batch_size'):
        return False
    inner.batch_size = batch_size
    return True


def process_input_path(explainer, wrapped_model, device, input_path, output_dir,
                       nsamples=200, batch_size=200, img_size=300):
    """
    Compute SHAP values for one image or every image in a directory.

    For each image a ``<image name>.npz`` file is written to ``output_dir``
    (see ``save_shap_data``). A failure on one image is logged and counted;
    processing continues with the next image.

    Args:
        explainer: SHAP explainer, normally a ``shap.GradientExplainer``.
        wrapped_model: Model returning class probabilities.
        device: Torch device the images are moved to.
        input_path: Image file, or directory searched for png/jpg/jpeg files.
        output_dir: Directory for the .npz results; created when missing.
        nsamples: Number of SHAP samples per image.
        batch_size: SHAP internal batch size (larger uses more GPU memory).
        img_size: Side length every image is resized to before inference.
    """
    os.makedirs(output_dir, exist_ok=True)

    image_paths = _collect_image_paths(input_path)
    total = len(image_paths)
    print(f"Found {total} images. Starting generation...", flush=True)
    print(f"SHAP params: nsamples={nsamples}, batch_size={batch_size}, img_size={img_size}", flush=True)

    # batch_size used to be logged only; hand it to SHAP so it takes effect.
    if isinstance(explainer, shap.GradientExplainer) and not _apply_batch_size(explainer, batch_size):
        print("Warning: could not apply batch_size to the explainer; "
              "its built-in batch size is used.", flush=True)

    start_time = time.time()
    success = 0
    failed = 0

    for idx, image_path in enumerate(image_paths, start=1):
        iter_start = time.time()
        try:
            base_name = os.path.basename(image_path)
            file_name = os.path.splitext(base_name)[0]
            output_path = os.path.join(output_dir, f"{file_name}.npz")

            image_tensor = load_image(image_path, img_size=img_size).to(device)

            with torch.no_grad():
                output = wrapped_model(image_tensor)
                prediction = torch.argmax(output, dim=1).cpu().item()
                probabilities = output[0].cpu().numpy()

            image_tensor.requires_grad = True
            if not isinstance(explainer, shap.GradientExplainer):
                explainer = shap.GradientExplainer(wrapped_model, explainer.data,
                                                   batch_size=batch_size)

            shap_values = explainer.shap_values(image_tensor, nsamples=nsamples)

            # Shape correction: keep only the predicted class and drop the
            # batch dimension, whichever layout SHAP returned.
            if isinstance(shap_values, list):
                shap_values = shap_values[prediction]
            if len(shap_values.shape) == 5:
                shap_values = shap_values[0, :, :, :, prediction]
            elif len(shap_values.shape) == 4 and shap_values.shape[0] == 1:
                shap_values = shap_values[0]

            save_shap_data(image_tensor, shap_values, prediction, probabilities, output_path)
            success += 1

            iter_time = time.time() - iter_start
            elapsed = time.time() - start_time
            avg = elapsed / idx
            remaining = avg * (total - idx)
            pct = 100.0 * idx / total
            print(
                f"[{idx}/{total}] ({pct:5.1f}%) {base_name} "
                f"| iter: {iter_time:.2f}s | avg: {avg:.2f}s "
                f"| elapsed: {_format_duration(elapsed)} "
                f"| eta: {_format_duration(remaining)}",
                flush=True,
            )

        except Exception as e:
            # Best effort by design: one bad image must not stop the batch.
            failed += 1
            print(f"[{idx}/{total}] FAILED {image_path}: {e}", flush=True)

    total_time = time.time() - start_time
    print("=" * 60, flush=True)
    print(
        f"Done. success={success}, failed={failed}, total={total} "
        f"| total_time={_format_duration(total_time)} "
        f"| avg={total_time / max(total, 1):.2f}s/img",
        flush=True,
    )
    print("=" * 60, flush=True)


# -----------------------------------------------------------------------------
# Main Execution
# -----------------------------------------------------------------------------

def _build_vit_cfg(args):
    """Collect the ViT architecture options from parsed CLI ``args`` into a dict."""
    return {
        'window_size': args.window_size,
        'hidden_size': args.hidden_size,
        'att_drop': args.att_drop,
        'drop_path_rate': args.drop_path_rate,
        # The CLI flag is --dropout; the model argument is drop_rate.
        'drop_rate': args.dropout,
        'layer_norm_eps': args.layer_norm_eps,
        'encoder_stride': args.encoder_stride,
        'embed_dim': args.embed_dim,
        'depths': args.depths,
        'mlp_ratio': args.mlp_ratio,
    }


def main():
    """Parse the command line and run ``prepare``, ``run`` or ``prepare-and-run``."""
    parser = argparse.ArgumentParser(description='SHAP Data Generator (ViT / Swinv2)')
    subparsers = parser.add_subparsers(dest='command', required=True)

    # Common parameters for model setup
    model_p = argparse.ArgumentParser(add_help=False)
    model_p.add_argument('--model_path', type=str, required=True, help='Path to the .pth checkpoint file')
    model_p.add_argument('--num_classes', type=int, default=2, help='Number of output classes (default: 2)')
    model_p.add_argument('--cuda_selection', type=int, default=0, help='GPU ID to use')
    model_p.add_argument('--num_background', type=int, default=50, help='Number of samples for SHAP background')
    model_p.add_argument('--explainer_path', type=str, default='shap_explainer_vit.pkl', help='Path to save/load explainer')
    model_p.add_argument('--img_size', type=int, default=300, help='Input image size (default: 300, matches DEEPScreen pipeline)')

    # ViT (Swinv2) architecture params — MUST match the trained checkpoint.
    # Defaults mirror config/sweep_vit.yaml.
    model_p.add_argument('--window_size', type=int, default=7, help='Swinv2 window size')
    model_p.add_argument('--hidden_size', type=int, default=768, help='Swinv2 hidden size')
    model_p.add_argument('--embed_dim', type=int, default=96, help='Swinv2 embedding dim')
    model_p.add_argument('--depths', type=int, nargs='+', default=[2, 2, 6, 2], help='Swinv2 depths per stage')
    model_p.add_argument('--mlp_ratio', type=float, default=4, help='Swinv2 MLP ratio')
    model_p.add_argument('--encoder_stride', type=int, default=32, help='Swinv2 encoder stride')
    model_p.add_argument('--att_drop', type=float, default=0.0, help='Attention probs dropout prob')
    model_p.add_argument('--drop_path_rate', type=float, default=0.1, help='Stochastic depth drop path rate')
    model_p.add_argument('--dropout', type=float, default=0.5, help='Hidden dropout prob (drop_rate)')
    model_p.add_argument('--layer_norm_eps', type=float, default=1e-5, help='LayerNorm epsilon')

    # Common parameters for running analysis
    run_p = argparse.ArgumentParser(add_help=False)
    run_p.add_argument('--input_path', type=str, required=True, help='Path to image or directory of images')
    run_p.add_argument('--output_dir', type=str, required=True, help='Directory to save .npz results')

    # Dummy parameters to maintain compatibility with visualization scripts
    run_p.add_argument('--percentile', type=int, default=95, help='Ignored in generator mode')
    run_p.add_argument('--alpha', type=float, default=0.5, help='Ignored in generator mode')

    run_p.add_argument('--nsamples', type=int, default=200,
                       help='SHAP nsamples (default 200)')
    run_p.add_argument('--batch_size', type=int, default=200,
                       help='SHAP internal batch size for GPU utilization')

    subparsers.add_parser('prepare', parents=[model_p], help='Initialize and save the SHAP explainer')

    run_parser = subparsers.add_parser('run', parents=[run_p], help='Run generation using existing explainer')
    run_parser.add_argument('--explainer_path', type=str, required=True)

    subparsers.add_parser('prepare-and-run', parents=[model_p, run_p], help='Initialize explainer and run generation')

    args = parser.parse_args()

    if args.command == 'run':
        # The image size is taken from the saved bundle, not the command line.
        explainer, wrapped_model, device, img_size = load_explainer(args.explainer_path)
    else:
        explainer, wrapped_model, device = prepare_explainer(
            args.model_path, args.num_classes, _build_vit_cfg(args),
            args.num_background, args.cuda_selection,
            args.explainer_path, img_size=args.img_size)
        img_size = args.img_size

    if explainer is None:
        # The cause was already printed; signal the failure to the shell.
        raise SystemExit(1)

    if args.command in ('run', 'prepare-and-run'):
        process_input_path(explainer, wrapped_model, device, args.input_path, args.output_dir,
                           nsamples=args.nsamples, batch_size=args.batch_size,
                           img_size=img_size)


if __name__ == '__main__':
    main()
Prepare and Run (Complete Workflow): +CUDA_VISIBLE_DEVICES=0 python shap_generator_yolo.py prepare-and-run \ + --model_path "/path/to/your/yolo_checkpoint.pth" \ + --num_classes 2 --model_size "yolo11m" \ + --num_background 50 \ + --explainer_path "yolo_explainer.pkl" \ + --input_path "/path/to/images_folder" \ + --output_dir "./output_results" \ + --cuda_selection 0 + +2. Prepare Only: +python shap_generator_yolo.py prepare \ + --model_path "/path/to/yolo_checkpoint.pth" \ + --explainer_path "yolo_explainer.pkl" + +3. Run Only: +python shap_generator_yolo.py run \ + --explainer_path "yolo_explainer.pkl" \ + --input_path "/path/to/single_image.png" \ + --output_dir "./output_results" +""" + +import os +import time +import torch +import numpy as np +import cv2 +import argparse +import pickle +from glob import glob +from models import YOLOv11Classifier +import warnings +import shap + +warnings.filterwarnings('ignore') + + +def _format_duration(seconds): + """Format seconds as Hh Mm Ss.""" + seconds = int(seconds) + h, rem = divmod(seconds, 3600) + m, s = divmod(rem, 60) + if h > 0: + return f"{h}h {m}m {s}s" + if m > 0: + return f"{m}m {s}s" + return f"{s}s" + +# ----------------------------------------------------------------------------- +# Data Saving Function +# ----------------------------------------------------------------------------- + +def save_shap_data(image_tensor, shap_values, prediction, probabilities, output_path): + """ + Saves data in compressed numpy format (.npz). 
+ """ + if isinstance(image_tensor, torch.Tensor): + img_np = image_tensor.cpu().detach().numpy() + img_np = img_np[0] + else: + img_np = image_tensor + + if isinstance(shap_values, torch.Tensor): + shap_values = shap_values.cpu().detach().numpy() + + np.savez_compressed( + output_path, + shap_values=shap_values, + image=img_np, + prediction=prediction, + probabilities=probabilities + ) + +# ----------------------------------------------------------------------------- +# Core Functions +# ----------------------------------------------------------------------------- + +class ModelWrapper(torch.nn.Module): + """ + Wraps the classifier so that SHAP receives softmax probabilities + and a clean single-tensor forward signature. + """ + def __init__(self, model): + super(ModelWrapper, self).__init__() + self.model = model + + def forward(self, x): + output = self.model(x) + # YOLOv11Classifier.forward already unwraps tuple/list outputs, + # but guard again here in case of upstream changes. + if isinstance(output, (tuple, list)): + output = output[0] + return torch.nn.functional.softmax(output, dim=1) + +def create_background_dataset(model, device, num_samples=50, img_size=300): + print(f"Creating background dataset with {num_samples} samples...", flush=True) + background_images = [] + for _ in range(num_samples): + img = np.ones((img_size, img_size, 3), dtype=np.uint8) * 255 + img = img.astype(np.float32) / 255.0 + img = img.transpose((2, 0, 1)) + background_images.append(img) + return torch.FloatTensor(np.array(background_images)).to(device) + +def _load_yolo_state_dict(model, checkpoint): + """ + Train pipeline saves checkpoints as {'model_state_dict': ..., 'optimizer_state_dict': ..., ...} + (see train_deepscreen.py). Plain state_dicts are also supported for flexibility. 
+ """ + if isinstance(checkpoint, dict) and 'model_state_dict' in checkpoint: + state_dict = checkpoint['model_state_dict'] + else: + state_dict = checkpoint + model.load_state_dict(state_dict) + +def prepare_explainer(model_path, num_classes, model_size, num_background, + cuda_selection, save_path, img_size=300): + print("="*60); print("Mode: Preparing SHAP Generator (YOLOv11)"); print("="*60, flush=True) + device = f'cuda:{cuda_selection}' if torch.cuda.is_available() else 'cpu' + print(f"Device: {device}", flush=True) + + try: + model = YOLOv11Classifier(num_classes=num_classes, model_size=model_size).to(device) + checkpoint = torch.load(model_path, map_location=device) + _load_yolo_state_dict(model, checkpoint) + model.eval() + except Exception as e: + print(f"Error loading model: {e}"); return None, None, None + + wrapped_model = ModelWrapper(model).to(device) + background_images = create_background_dataset(model, device, + num_samples=num_background, + img_size=img_size) + + explainer = shap.GradientExplainer(wrapped_model, background_images) + + explainer_data = { + 'explainer': explainer, + 'model_path': model_path, + 'num_classes': num_classes, + 'model_size': model_size, + 'img_size': img_size, + 'device': device, + 'wrapped_model': wrapped_model, + } + + with open(save_path, 'wb') as f: + pickle.dump(explainer_data, f) + return explainer, wrapped_model, device + +def load_explainer(explainer_path): + if not os.path.exists(explainer_path): return None, None, None, None + with open(explainer_path, 'rb') as f: + data = pickle.load(f) + data['wrapped_model'].to(data['device']) + data['wrapped_model'].eval() + return data['explainer'], data['wrapped_model'], data['device'], data.get('img_size', 300) + +def load_image(image_path, img_size=300): + img = cv2.imread(image_path) + if img is None: raise FileNotFoundError(f"Error: Could not load {image_path}") + if img.shape[:2] != (img_size, img_size): + img = cv2.resize(img, (img_size, img_size)) + + 
def process_input_path(explainer, wrapped_model, device, input_path, output_dir,
                       nsamples=200, batch_size=200, img_size=300):
    """
    Compute SHAP values for one image, or for every image in a directory, and
    write one ``<image stem>.npz`` per input into ``output_dir``.

    Args:
        explainer: SHAP explainer. If it is not a ``shap.GradientExplainer``
            it is rebuilt as one from ``explainer.data``.
        wrapped_model: Model returning class probabilities for an image batch.
        device: Torch device the image tensors are moved to.
        input_path: A single image file or a directory of images.
        output_dir: Directory for the ``.npz`` results (created if missing).
        nsamples: Forwarded to ``explainer.shap_values``.
        batch_size: Accepted for CLI compatibility and echoed in the log; it
            is not passed to the explainer by this function.
        img_size: Side length passed to ``load_image``.

    Failures on individual images are reported and counted; they do not stop
    the run.
    """
    # exist_ok avoids the check-then-create race of exists() + makedirs().
    os.makedirs(output_dir, exist_ok=True)

    image_exts = ('.png', '.jpg', '.jpeg')
    image_paths = []
    if os.path.isfile(input_path):
        image_paths.append(input_path)
    elif os.path.isdir(input_path):
        # Case-insensitive extension match (so IMG.PNG is picked up too), and
        # sorted so that the processing order is reproducible between runs.
        image_paths = sorted(
            os.path.join(input_path, entry)
            for entry in os.listdir(input_path)
            if not entry.startswith('.')
            and entry.lower().endswith(image_exts)
            and os.path.isfile(os.path.join(input_path, entry))
        )
    else:
        print(f"Input path does not exist: {input_path}", flush=True)

    total = len(image_paths)
    print(f"Found {total} images. Starting generation...", flush=True)
    print(f"SHAP params: nsamples={nsamples}, batch_size={batch_size}, img_size={img_size}", flush=True)

    start_time = time.time()
    success = 0
    failed = 0

    for idx, image_path in enumerate(image_paths, start=1):
        iter_start = time.time()
        try:
            base_name = os.path.basename(image_path)
            file_name = os.path.splitext(base_name)[0]
            output_path = os.path.join(output_dir, f"{file_name}.npz")

            image_tensor = load_image(image_path, img_size=img_size).to(device)

            # Plain forward pass first: prediction and probabilities do not
            # need gradients.
            with torch.no_grad():
                output = wrapped_model(image_tensor)
                prediction = torch.argmax(output, dim=1).cpu().item()
                probabilities = output[0].cpu().numpy()

            image_tensor.requires_grad = True
            # Only triggers until the first successful rebuild; afterwards the
            # rebound explainer is already a GradientExplainer.
            if not isinstance(explainer, shap.GradientExplainer):
                explainer = shap.GradientExplainer(wrapped_model, explainer.data)

            shap_values = explainer.shap_values(image_tensor, nsamples=nsamples)

            # Reduce whatever layout SHAP returned to the attribution map of
            # the predicted class for the single input image.
            if isinstance(shap_values, list):
                shap_values = shap_values[prediction]
            if len(shap_values.shape) == 5:
                shap_values = shap_values[0, :, :, :, prediction]
            elif len(shap_values.shape) == 4 and shap_values.shape[0] == 1:
                shap_values = shap_values[0]

            save_shap_data(image_tensor, shap_values, prediction, probabilities, output_path)
            success += 1

            iter_time = time.time() - iter_start
            elapsed = time.time() - start_time
            avg = elapsed / idx
            remaining = avg * (total - idx)
            pct = 100.0 * idx / total
            print(
                f"[{idx}/{total}] ({pct:5.1f}%) {base_name} "
                f"| iter: {iter_time:.2f}s | avg: {avg:.2f}s "
                f"| elapsed: {_format_duration(elapsed)} "
                f"| eta: {_format_duration(remaining)}",
                flush=True,
            )

        except Exception as e:
            # Best effort by design: one bad image must not abort a long batch.
            failed += 1
            print(f"[{idx}/{total}] FAILED {image_path}: {e}", flush=True)

    total_time = time.time() - start_time
    print("=" * 60, flush=True)
    print(
        f"Done. success={success}, failed={failed}, total={total} "
        f"| total_time={_format_duration(total_time)} "
        f"| avg={total_time / max(total, 1):.2f}s/img",
        flush=True,
    )
    print("=" * 60, flush=True)
yolo11n / yolo11s / yolo11m / yolo11l / yolo11x (default: yolo11m, matches train_deepscreen.py)') + model_p.add_argument('--cuda_selection', type=int, default=0, help='GPU ID to use') + model_p.add_argument('--num_background', type=int, default=50, help='Number of samples for SHAP background') + model_p.add_argument('--explainer_path', type=str, default='shap_explainer_yolo.pkl', help='Path to save/load explainer') + model_p.add_argument('--img_size', type=int, default=300, help='Input image size (default: 300, matches DEEPScreen pipeline)') + + # Common parameters for running analysis + run_p = argparse.ArgumentParser(add_help=False) + run_p.add_argument('--input_path', type=str, required=True, help='Path to image or directory of images') + run_p.add_argument('--output_dir', type=str, required=True, help='Directory to save .npz results') + + # Dummy parameters to maintain compatibility with visualization scripts + run_p.add_argument('--percentile', type=int, default=95, help='Ignored in generator mode') + run_p.add_argument('--alpha', type=float, default=0.5, help='Ignored in generator mode') + + run_p.add_argument('--nsamples', type=int, default=200, + help='SHAP nsamples (default 200)') + run_p.add_argument('--batch_size', type=int, default=200, + help='SHAP internal batch size for GPU utilization') + + subparsers.add_parser('prepare', parents=[model_p], help='Initialize and save the SHAP explainer') + + run_parser = subparsers.add_parser('run', parents=[run_p], help='Run generation using existing explainer') + run_parser.add_argument('--explainer_path', type=str, required=True) + + subparsers.add_parser('prepare-and-run', parents=[model_p, run_p], help='Initialize explainer and run generation') + + args = parser.parse_args() + + if args.command == 'prepare': + prepare_explainer(args.model_path, args.num_classes, args.model_size, + args.num_background, args.cuda_selection, + args.explainer_path, img_size=args.img_size) + + elif args.command == 'run': + 
explainer, wrapped_model, device, img_size = load_explainer(args.explainer_path) + if explainer: + process_input_path(explainer, wrapped_model, device, args.input_path, args.output_dir, + nsamples=args.nsamples, batch_size=args.batch_size, + img_size=img_size) + + elif args.command == 'prepare-and-run': + explainer, wrapped_model, device = prepare_explainer( + args.model_path, args.num_classes, args.model_size, + args.num_background, args.cuda_selection, + args.explainer_path, img_size=args.img_size) + if explainer: + process_input_path(explainer, wrapped_model, device, args.input_path, args.output_dir, + nsamples=args.nsamples, batch_size=args.batch_size, + img_size=args.img_size) + +if __name__ == '__main__': + main() From 062c9a3e74177140792c2e34639f9e53a591881a Mon Sep 17 00:00:00 2001 From: furkannecatiinan Date: Mon, 22 Jun 2026 22:09:47 +0300 Subject: [PATCH 2/4] shaps folder --- shap_generator_cnn.py => shap/shap_generator_cnn.py | 0 shap_generator_vit.py => shap/shap_generator_vit.py | 0 shap_generator_yolo.py => shap/shap_generator_yolo.py | 0 3 files changed, 0 insertions(+), 0 deletions(-) rename shap_generator_cnn.py => shap/shap_generator_cnn.py (100%) rename shap_generator_vit.py => shap/shap_generator_vit.py (100%) rename shap_generator_yolo.py => shap/shap_generator_yolo.py (100%) diff --git a/shap_generator_cnn.py b/shap/shap_generator_cnn.py similarity index 100% rename from shap_generator_cnn.py rename to shap/shap_generator_cnn.py diff --git a/shap_generator_vit.py b/shap/shap_generator_vit.py similarity index 100% rename from shap_generator_vit.py rename to shap/shap_generator_vit.py diff --git a/shap_generator_yolo.py b/shap/shap_generator_yolo.py similarity index 100% rename from shap_generator_yolo.py rename to shap/shap_generator_yolo.py From cc37ccdedb4c7dcf50467716fd0923a895d47cf6 Mon Sep 17 00:00:00 2001 From: furkannecatiinan Date: Mon, 22 Jun 2026 22:16:51 +0300 Subject: [PATCH 3/4] shap generator helpers --- shap/USAGE.md | 84 
++++++++++++ shap/generate_images.py | 152 ++++++++++++++++++++++ shap/prepare_csv.py | 60 +++++++++ shap/shap_cli.py | 280 ++++++++++++++++++++++++++++++++++++++++ 4 files changed, 576 insertions(+) create mode 100644 shap/USAGE.md create mode 100644 shap/generate_images.py create mode 100644 shap/prepare_csv.py create mode 100644 shap/shap_cli.py diff --git a/shap/USAGE.md b/shap/USAGE.md new file mode 100644 index 0000000..16885e4 --- /dev/null +++ b/shap/USAGE.md @@ -0,0 +1,84 @@ +# SHAP Vis — Usage + +End-to-end pipeline that turns SHAP `.npz` outputs into per-molecule attention +PNGs, packed into one ZIP per model. `app8.py` is the interactive Streamlit +version of the same flow. + +## Pipeline overview + +``` +prepare_csv.py -> molecules.csv (fetch SMILES from ChEMBL / PubChem) +generate_images.py -> mol_images/ (transparent 300x300 molecule PNGs) +shap_cli.py -> output_zips/<model>.zip (terminal version of app7/app8 "Download ZIP") +``` + +## Requirements + +```bash +pip install numpy opencv-python matplotlib pillow rdkit pandas requests \ + chembl_webresource_client streamlit +``` + +## 1. Build `molecules.csv` + +Scans `shap_numpy_data/` for `*_0.npz`, collects unique molecule IDs, and looks +up each SMILES (ChEMBL first, PubChem fallback). + +```bash +python prepare_csv.py # writes ./molecules.csv +``` + +CSV columns: `molecule_id, smiles`. + +## 2. Generate molecule images + +Renders transparent-background PNGs from the SMILES list. + +```bash +python generate_images.py --input molecules.csv --output_dir ./mol_images +``` + +Missing PNGs are also generated on demand by `shap_cli.py`, so this step is +optional if `molecules.csv` exists. + +## 3. Render attention ZIPs + +Provide either a single model folder or a parent folder of model subfolders.
+ +```bash +# single model folder -> CHEMBL301_cnn.zip +python shap_cli.py -i shap_numpy_data/CHEMBL301_cnn -o ./output_zips + +# parent folder -> one ZIP per subfolder +python shap_cli.py -i shap_numpy_data -o ./output_zips +``` + +Key options (defaults match the app7/app8 sliders): + +| Flag | Default | Meaning | +|------|---------|---------| +| `--input`, `-i` | — | NPZ folder or parent of model subfolders (required) | +| `--output_dir`, `-o` | `./output_zips` | Where ZIPs are written | +| `--mol_images_dir` | `./mol_images` | Transparent molecule PNGs | +| `--molecules_csv` | `./molecules.csv` | SMILES source for missing PNGs | +| `--hotspot_p` | `95.0` | Focus threshold (Top %) | +| `--blur_sigma` | `5.7` | Smoothing | +| `--gamma` | `1.50` | Intensity | +| `--alpha` | `0.75` | Opacity | + +## Run everything at once + +```bash +./run_pipeline.sh # defaults +./run_pipeline.sh --input shap_numpy_data +./run_pipeline.sh --skip-csv --skip-images # only re-render the ZIPs +``` + +## Interactive version + +```bash +streamlit run app8.py +``` + +Same parameters exposed as sliders, with a "Download ZIP" button equivalent to +`shap_cli.py`. diff --git a/shap/generate_images.py b/shap/generate_images.py new file mode 100644 index 0000000..25fae8b --- /dev/null +++ b/shap/generate_images.py @@ -0,0 +1,152 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Molekül Görüntüsü Üretici — Saydam Arka Plan +============================================= +Verilen SMILES listesinden 300x300 saydam arka planlı PNG üretir. 
def smiles_to_transparent_png(smiles: str, size: int = 300) -> "Image":
    """
    Render a SMILES string as a ``size`` x ``size`` PIL RGBA image whose
    background has been made transparent.

    Background detection is colour based: pixels that are both bright and
    nearly colourless count as background and receive a graded alpha, while
    coloured or dark pixels (bonds, coloured atom labels) stay fully opaque.

    Raises:
        ValueError: If RDKit cannot parse ``smiles``.
    """
    mol = Chem.MolFromSmiles(smiles)
    if mol is None:
        raise ValueError(f"Geçersiz SMILES: {smiles}")

    # Rasterise with RDKit's Cairo drawer; the canvas background is filled in
    # (the alpha channel is rewritten below).
    canvas = rdMolDraw2D.MolDraw2DCairo(size, size)
    canvas.drawOptions().clearBackground = True
    canvas.DrawMolecule(mol)
    canvas.FinishDrawing()

    rendered = Image.open(BytesIO(canvas.GetDrawingText())).convert("RGBA")
    pixels = np.array(rendered, dtype=float)
    rgb = pixels[:, :, :3]

    # Brightness: how light the pixel is (mean of R, G and B).
    brightness = (pixels[:, :, 0] + pixels[:, :, 1] + pixels[:, :, 2]) / 3.0

    # Colourfulness: spread between the largest and smallest channel. Large
    # for coloured pixels, small for grey/white ones.
    colorfulness = rgb.max(axis=2) - rgb.min(axis=2)

    # Background = light (brightness > 200) and colourless (spread < 15).
    is_background = (brightness > 200) & (colorfulness < 15)

    # Background pixels fade out with brightness, which keeps anti-aliased
    # edges smooth; every other pixel is fully opaque.
    faded_alpha = np.clip((255.0 - brightness) * 3.0, 0, 255)
    pixels[:, :, 3] = np.where(is_background, faded_alpha, 255.0)

    return Image.fromarray(pixels.astype(np.uint8))
+ molecules = {"CHEMBL100675": "CCO...", "DILI1": "c1ccc..."} + """ + os.makedirs(output_dir, exist_ok=True) + success, fail = 0, 0 + + for mol_id, smiles in molecules.items(): + out_path = os.path.join(output_dir, f"{mol_id}.png") + try: + img = smiles_to_transparent_png(smiles, size=size) + img.save(out_path, format="PNG") + print(f"✓ {mol_id} → {out_path}") + success += 1 + except Exception as e: + print(f"✗ {mol_id}: {e}") + fail += 1 + + print(f"\n✅ Tamamlandı: {success} başarılı, {fail} başarısız.") + + +def main(): + parser = argparse.ArgumentParser(description="Saydam arka planlı molekül PNG üretici") + parser.add_argument("--input", type=str, required=True, + help="CSV dosyası (sütunlar: molecule_id, smiles)") + parser.add_argument("--output_dir", type=str, default="./mol_images", + help="Çıktı klasörü (varsayılan: ./mol_images)") + parser.add_argument("--size", type=int, default=300, + help="Görüntü boyutu piksel (varsayılan: 300)") + args = parser.parse_args() + + if not HAS_RDKIT or not HAS_PIL: + print("Eksik bağımlılık. 
def get_smiles_from_pubchem(name):
    """
    Look up a compound by name on PubChem (PUG REST) and return its SMILES.

    Used as the fallback for identifiers that the ChEMBL lookup cannot resolve.

    Args:
        name: Compound name or identifier to search for.

    Returns:
        The 'CanonicalSMILES' value of the first returned property record, or
        ``None`` when the request fails, the status is not 200, or the response
        does not have the expected structure.
    """
    from urllib.parse import quote

    # Escape the name so spaces, slashes etc. cannot break the URL path.
    url = (
        "https://pubchem.ncbi.nlm.nih.gov/rest/pug/compound/name/"
        f"{quote(str(name), safe='')}/property/CanonicalSMILES/JSON"
    )
    try:
        response = requests.get(url, timeout=5)
        # Was `response.status_status`, which raised AttributeError on every
        # call; the blanket `except` hid it, so this lookup never succeeded.
        if response.status_code == 200:
            return response.json()['PropertyTable']['Properties'][0]['CanonicalSMILES']
    except (requests.RequestException, KeyError, IndexError, TypeError, ValueError):
        # Best-effort lookup: network errors and unexpected payloads mean
        # "not found"; programming errors are no longer swallowed.
        return None
    return None
Yol: PubChem (ChEMBL bulamazsa veya DILI id'si ise) + if not smiles: + smiles = get_smiles_from_pubchem(m_id) + + if smiles: + print("✅ Bulundu") + else: + print("❌ Bulunamadı") + + results.append({"molecule_id": m_id, "smiles": smiles}) + + df = pd.DataFrame(results) + df.to_csv(output_csv, index=False) + print(f"\n📁 İşlem bitti! '{output_csv}' kontrol edebilirsin.") + +if __name__ == "__main__": + update_csv_with_smiles("./shap_numpy_data") \ No newline at end of file diff --git a/shap/shap_cli.py b/shap/shap_cli.py new file mode 100644 index 0000000..e9281e7 --- /dev/null +++ b/shap/shap_cli.py @@ -0,0 +1,280 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +SHAP Visualization CLI +====================== + +When you provide a single model folder (e.g. shap_numpy_data/CHEMBL301_cnn) it +generates the processed PNGs for every molecule in that folder and packs them +into a ZIP. + +When you provide a parent folder (e.g. shap_numpy_data) it processes each model +subfolder underneath it separately and produces a separate ZIP for each one +(e.g. CHEMBL301_cnn.zip, CHEMBL301_yolo.zip). 
def load_and_aggregate_data(base_name, input_dir):
    """
    Load every ``<base_name>_<angle>.npz`` in ``input_dir`` and average them.

    The integer after the last underscore is used as a rotation angle in
    degrees: each file's SHAP map and image are rotated back by that angle
    (via ``rotate_image_back``) before averaging.

    Args:
        base_name: Molecule identifier, i.e. the file name without the
            ``_<angle>.npz`` suffix.
        input_dir: Directory containing the ``.npz`` files.

    Returns:
        ``None`` if no file matches, otherwise a dict with the element-wise
        means: 'shap' and 'image' (both transposed from the stored
        channel-first layout to channel-last) and 'probs'.
    """
    # fullmatch instead of search: an unanchored search would also accept
    # files of any molecule whose id merely ends with base_name
    # (e.g. "DILI1" matching "NONDILI1_0.npz").
    pattern = re.compile(re.escape(base_name) + r"_(\d+)\.npz")
    matched_files = []
    for f_path in glob(os.path.join(input_dir, "*.npz")):
        match = pattern.fullmatch(os.path.basename(f_path))
        if match:
            matched_files.append((f_path, int(match.group(1))))
    if not matched_files:
        return None

    # glob order is filesystem dependent; sort so the float mean is reproducible.
    matched_files.sort(key=lambda item: (item[1], item[0]))

    shap_list, img_list, probs_list = [], [], []
    for f_path, angle in matched_files:
        # NpzFile keeps the archive open until closed; the context manager
        # releases the handle once the arrays have been read into memory.
        with np.load(f_path) as data:
            raw_shap = data["shap_values"].transpose((1, 2, 0))
            raw_img = data["image"].transpose((1, 2, 0))
            probs = data["probabilities"]
        shap_list.append(rotate_image_back(raw_shap, angle, is_shap=True))
        img_list.append(rotate_image_back(raw_img, angle, is_shap=False))
        probs_list.append(probs)

    return {
        "shap": np.mean(shap_list, axis=0),
        "image": np.mean(img_list, axis=0),
        "probs": np.mean(probs_list, axis=0),
    }
def collect_base_names(input_dir):
    """
    Return the sorted, unique molecule names found in ``input_dir``.

    A name is the file name of a ``*.npz`` file with its trailing
    ``_<digits>.npz`` removed. Files that do not carry such a suffix are
    ignored: previously they were returned unchanged (e.g. "foo.npz"), which
    produced phantom entries that can never be resolved to data files.

    Args:
        input_dir: Directory to scan (not recursive).

    Returns:
        A sorted list of names; empty when nothing matches.
    """
    suffix = re.compile(r"_\d+\.npz$")
    names = set()
    for path in glob(os.path.join(input_dir, "*.npz")):
        base, hits = suffix.subn("", os.path.basename(path))
        # hits == 0: no "_<digits>.npz" suffix; empty base: nothing before it.
        if hits and base:
            names.add(base)
    return sorted(names)
def find_model_subdirs(parent):
    """
    List the immediate subdirectories of ``parent`` that hold ``.npz`` files.

    Entries are visited in sorted name order, so the returned paths (each one
    ``parent`` joined with the subdirectory name) are sorted as well.
    """
    candidates = (os.path.join(parent, entry) for entry in sorted(os.listdir(parent)))
    return [
        path
        for path in candidates
        if os.path.isdir(path) and glob(os.path.join(path, "*.npz"))
    ]
default=95.0, + help="Focus threshold (Top %%) — Streamlit slider") + p.add_argument("--blur_sigma", type=float, default=5.7, + help="Smoothing (Blur sigma)") + p.add_argument("--gamma", type=float, default=1.50, + help="Intensity (Gamma)") + p.add_argument("--alpha", type=float, default=0.75, + help="Opacity") + args = p.parse_args() + + if not os.path.isdir(args.input): + print(f"ERROR: input folder does not exist: {args.input}", file=sys.stderr) + sys.exit(1) + + os.makedirs(args.output_dir, exist_ok=True) + + direct_npz = bool(glob(os.path.join(args.input, "*.npz"))) + targets = [] + if direct_npz: + targets.append(args.input) + else: + targets = find_model_subdirs(args.input) + if not targets: + print(f"ERROR: no .npz found under {args.input}", file=sys.stderr) + sys.exit(1) + + print(f"Parameters: hotspot_p={args.hotspot_p} blur={args.blur_sigma} " + f"gamma={args.gamma} alpha={args.alpha}") + print(f"Number of folders to process: {len(targets)}") + + any_ok = False + for sub in targets: + label = os.path.basename(os.path.normpath(sub)) + out_zip = os.path.join(args.output_dir, f"{label}.zip") + print(f"\n== {label} ==") + ok = process_directory( + input_dir=sub, + output_zip=out_zip, + mol_images_dir=args.mol_images_dir, + molecules_csv_path=args.molecules_csv, + blur_sigma=args.blur_sigma, + gamma=args.gamma, + hotspot_p=args.hotspot_p, + alpha=args.alpha, + ) + any_ok = any_ok or ok + + sys.exit(0 if any_ok else 2) + + +if __name__ == "__main__": + main() From 2d1323f31888d8b5b18d100578325f8017ec5c57 Mon Sep 17 00:00:00 2001 From: furkannecatiinan Date: Sun, 28 Jun 2026 20:15:51 +0300 Subject: [PATCH 4/4] tsne umap --- visualisation/README.md | 51 ++++++ visualisation/extract_embeddings.py | 262 ++++++++++++++++++++++++++++ visualisation/plot_projections.py | 236 +++++++++++++++++++++++++ visualisation/run_all.py | 93 ++++++++++ 4 files changed, 642 insertions(+) create mode 100644 visualisation/README.md create mode 100644 
visualisation/extract_embeddings.py create mode 100644 visualisation/plot_projections.py create mode 100644 visualisation/run_all.py diff --git a/visualisation/README.md b/visualisation/README.md new file mode 100644 index 0000000..1b882cb --- /dev/null +++ b/visualisation/README.md @@ -0,0 +1,51 @@ +# Embedding Visualisation (t-SNE / UMAP) + +Reproduces the reference figure: 2D projections of model embeddings, colored by +activity (**red = active, blue = inactive**), for each task and each architecture +(CNNModel1/CNNModel2, ViT, YOLOv11). + +Two stages, mirroring `shap/`: + +1. **`extract_embeddings.py`** — runs a trained `.pth` over a data split and saves + the **penultimate feature vector** (the input to the final classifier `Linear`, + captured with a `forward_pre_hook`) per compound → `.npz`. +2. **`plot_projections.py`** — projects those embeddings with t-SNE / UMAP and + draws per-task panels and multi-task grids. + +`run_all.py` chains both over a whole folder of models. + +## Why the penultimate layer +The final classifier reads a fixed-width feature vector; that vector is the +learned representation. Projecting it shows whether the model separates +active/inactive — exactly what the reference t-SNE panels visualise. 
The hook +target per architecture: + +| Architecture | Final Linear (hooked) | Embedding dim | +|---------------------|-----------------------|---------------| +| CNNModel1/CNNModel2 | `fc3` | `fc2` | +| ViT (Swinv2) | `vit.classifier` | `hidden_size` | +| YOLOv11 | `model[-1].linear/fc` | head in-feat | + +## Quick start (on the training server) + +```bash +# One model: +python visualisation/extract_embeddings.py \ + --model_path trained_models/.../CHEMBL4282-...-CNNModel1-512-256-...-state_dict.pth \ + --task CHEMBL4282 --split test --cuda 0 + +# Everything under trained_models/, then build grids + panels: +python visualisation/run_all.py --models_dir trained_models --split test --cuda 0 +``` + +Outputs: +- embeddings → `visualisation/embeddings/____.npz` +- figures → `visualisation/figures/` (`grid____.png`, per-task panels) + +## Notes +- Architecture and CNN `fc1/fc2` are auto-parsed from the canonical filename; + override with `--model_name/--fc1/--fc2` if a name is non-standard. +- ViT params are read from `config/config.yaml`. +- t-SNE perplexity / UMAP n_neighbors auto-scale to small test splits. +- UMAP needs `pip install umap-learn`; t-SNE uses scikit-learn (already a dep). +- Embeddings use the **test** split by default (held-out, matches the ROC panels). diff --git a/visualisation/extract_embeddings.py b/visualisation/extract_embeddings.py new file mode 100644 index 0000000..97fc334 --- /dev/null +++ b/visualisation/extract_embeddings.py @@ -0,0 +1,262 @@ +#!/usr/bin/env python +# -*- coding: utf-8 -*- +""" +Embedding Extractor +=================== +Loads a trained DEEPScreen model (CNNModel1/CNNModel2/ViT/YOLOv11) and runs a +forward pass over a data split, capturing the penultimate feature vector (the +input to the final classifier Linear) for every compound. These embeddings are +what t-SNE / UMAP are computed on in plot_projections.py. + +It does NOT visualise; it only writes raw data to a .npz, mirroring the two-stage +design of the shap/ tooling. 
def parse_model_filename(path):
    """
    Best-effort parse of a model checkpoint file name.

    Only the base name of ``path`` is inspected. The architecture is the first
    of CNNModel1, CNNModel2, ViT, YOLOv11 that occurs in the name; for CNN
    models the two integers following ``-CNNModel1-`` / ``-CNNModel2-`` are
    read as fc1 and fc2.

    Returns:
        A dict holding whichever of 'model_name', 'fc1', 'fc2' could be
        inferred; empty when nothing is recognised.
    """
    stem = os.path.basename(path)
    parsed = {}

    architecture = next(
        (tag for tag in ("CNNModel1", "CNNModel2", "ViT", "YOLOv11") if tag in stem),
        None,
    )
    if architecture is not None:
        parsed["model_name"] = architecture

    # The layer widths directly follow the CNN model name, dash separated.
    widths = re.search(r"-(CNNModel[12])-(\d+)-(\d+)-", stem)
    if widths is not None:
        parsed["fc1"] = int(widths.group(2))
        parsed["fc2"] = int(widths.group(3))
    return parsed
+ """ + if model_name in ("CNNModel1", "CNNModel2"): + return model.fc3 + if model_name == "ViT": + # Swinv2ForImageClassification.classifier is a Linear (or Identity if 0 labels) + clf = model.vit.classifier + if isinstance(clf, torch.nn.Linear): + return clf + # Fallback: last Linear in the module tree + if model_name == "YOLOv11": + head = model.model.model[-1] + if hasattr(head, "linear") and isinstance(head.linear, torch.nn.Linear): + return head.linear + if hasattr(head, "fc") and isinstance(head.fc, torch.nn.Linear): + return head.fc + # Generic fallback: deepest-registered Linear + last = None + for m in model.modules(): + if isinstance(m, torch.nn.Linear): + last = m + if last is None: + raise RuntimeError("Could not locate a final Linear layer for hooking.") + return last + + +# ----------------------------------------------------------------------------- +# Extraction +# ----------------------------------------------------------------------------- +def extract(model_path, task, split, datasets_path, model_name, fc1, fc2, + dropout, vit_cfg, batch_size, device): + print("=" * 70) + print(f"Task={task} split={split} model={model_name}") + print(f"Model file: {model_path}") + print(f"Device: {device}") + print("=" * 70, flush=True) + + model = build_model(model_name, fc1, fc2, dropout, vit_cfg).to(device) + state = torch.load(model_path, map_location=device) + # tolerate {'state_dict': ...} wrappers + if isinstance(state, dict) and "state_dict" in state and not any( + k.startswith(("conv", "fc", "vit", "model", "bn")) for k in state + ): + state = state["state_dict"] + missing, unexpected = model.load_state_dict(state, strict=False) + if missing: + print(f" [warn] missing keys: {len(missing)} (e.g. {missing[:3]})") + if unexpected: + print(f" [warn] unexpected keys: {len(unexpected)} (e.g. 
{unexpected[:3]})") + model.eval() + + final_linear = find_final_linear(model, model_name) + + captured = {} + + def pre_hook(module, inputs): + # inputs[0]: (B, D) features feeding the classifier + captured["feat"] = inputs[0].detach() + + handle = final_linear.register_forward_pre_hook(pre_hook) + + dataset = DEEPScreenDataset(task, split, parent_path=datasets_path) + loader = DataLoader(dataset, batch_size=batch_size, shuffle=False, num_workers=4) + print(f"Loaded {len(dataset)} samples.", flush=True) + + emb_list, lbl_list, pred_list, prob_list, id_list = [], [], [], [], [] + + with torch.no_grad(): + for bi, (imgs, labels, comp_ids) in enumerate(loader, 1): + imgs = imgs.float().to(device) + logits = model(imgs) + if isinstance(logits, (tuple, list)): + logits = logits[0] + probs = torch.softmax(logits, dim=1) + preds = torch.argmax(probs, dim=1) + + feat = captured.get("feat") + if feat is None: + raise RuntimeError("Hook did not capture features.") + feat = feat.reshape(feat.size(0), -1) + + emb_list.append(feat.cpu().numpy().astype(np.float32)) + lbl_list.append(np.asarray(labels).astype(np.int64)) + pred_list.append(preds.cpu().numpy().astype(np.int64)) + prob_list.append(probs.cpu().numpy().astype(np.float32)) + id_list.extend(list(comp_ids)) + + if bi % 10 == 0 or bi == len(loader): + print(f" batch {bi}/{len(loader)}", flush=True) + + handle.remove() + + embeddings = np.concatenate(emb_list, axis=0) + labels = np.concatenate(lbl_list, axis=0) + preds = np.concatenate(pred_list, axis=0) + probs = np.concatenate(prob_list, axis=0) + comp_ids = np.asarray(id_list) + + print(f"Embeddings: {embeddings.shape} | active={int((labels==1).sum())} " + f"inactive={int((labels==0).sum())}", flush=True) + return dict( + embeddings=embeddings, labels=labels, preds=preds, probs=probs, + comp_ids=comp_ids, task=task, model_name=model_name, split=split, + ) + + +def main(): + ap = argparse.ArgumentParser(description="Extract penultimate embeddings.") + 
ap.add_argument("--model_path", required=True) + ap.add_argument("--task", required=True, help="Target/task folder id, e.g. CHEMBL4282") + ap.add_argument("--split", default="test", + choices=["test", "training", "validation", "all"]) + ap.add_argument("--datasets_path", default=DEFAULT_DATASETS) + ap.add_argument("--out_dir", default=os.path.join(PROJECT_ROOT, "visualisation", "embeddings")) + ap.add_argument("--model_name", default=None, + choices=[None, "CNNModel1", "CNNModel2", "ViT", "YOLOv11"]) + ap.add_argument("--fc1", type=int, default=None) + ap.add_argument("--fc2", type=int, default=None) + ap.add_argument("--dropout", type=float, default=0.2) + ap.add_argument("--batch_size", type=int, default=64) + ap.add_argument("--cuda", type=int, default=0) + ap.add_argument("--config", default=os.path.join(PROJECT_ROOT, "config", "config.yaml"), + help="YAML with ViT params (used only for ViT).") + args = ap.parse_args() + + parsed = parse_model_filename(args.model_path) + model_name = args.model_name or parsed.get("model_name") + if model_name is None: + raise SystemExit("Could not infer --model_name from filename; pass it explicitly.") + fc1 = args.fc1 if args.fc1 is not None else parsed.get("fc1", 128) + fc2 = args.fc2 if args.fc2 is not None else parsed.get("fc2", 256) + + vit_cfg = {} + if model_name == "ViT": + import yaml + with open(args.config) as f: + params = yaml.safe_load(f)["parameters"] + vit_cfg = params + + device = (f"cuda:{args.cuda}" if torch.cuda.is_available() else "cpu") + + result = extract( + args.model_path, args.task, args.split, args.datasets_path, + model_name, fc1, fc2, args.dropout, vit_cfg, args.batch_size, device, + ) + + os.makedirs(args.out_dir, exist_ok=True) + out_path = os.path.join(args.out_dir, f"{args.task}__{model_name}__{args.split}.npz") + np.savez_compressed(out_path, **result) + print(f"Saved -> {out_path}") + + +if __name__ == "__main__": + main() diff --git a/visualisation/plot_projections.py 
# -----------------------------------------------------------------------------
# Loading
# -----------------------------------------------------------------------------
def collect_npz(inputs):
    """
    Expand a list of folders and/or .npz paths into a flat list of .npz paths.

    Folders contribute their direct ``*.npz`` children in sorted order; any
    other entry is kept only if its name ends with ``.npz``.
    """
    paths = []
    for inp in inputs:
        if os.path.isdir(inp):
            paths.extend(sorted(glob.glob(os.path.join(inp, "*.npz"))))
        elif inp.endswith(".npz"):
            paths.append(inp)
    return paths


def load_npz(path):
    """
    Load the fields needed for plotting from one embedding .npz file.

    Returns a dict with keys: embeddings, labels (cast to int), task,
    model_name, split (the last three as plain ``str``).
    """
    # NOTE(review): allow_pickle=True lets np.load unpickle object arrays, which
    # can execute arbitrary code -- only point this at trusted files. Kept so
    # that files containing object arrays still load.
    # The context manager closes the underlying zip file once the arrays are read.
    with np.load(path, allow_pickle=True) as d:
        return dict(
            embeddings=d["embeddings"],
            labels=d["labels"].astype(int),
            task=str(d["task"]),
            model_name=str(d["model_name"]),
            split=str(d["split"]),
        )


# -----------------------------------------------------------------------------
# Projection
# -----------------------------------------------------------------------------
def _tsne_perplexity(n):
    """
    Pick a t-SNE perplexity for ``n`` samples.

    Targets roughly (n - 1) / 3 within [5, 30], then caps at n - 1 because
    scikit-learn requires perplexity < n_samples (the floor of 5 alone would
    be rejected for n <= 5).
    """
    return min(max(5, min(30, (n - 1) // 3)), n - 1)


def project(embeddings, method, seed=42):
    """
    Project embeddings to 2D with t-SNE (``"tsne"``) or UMAP (``"umap"``).

    Features are standardised per column before projection.

    Raises:
        ValueError: fewer than 3 rows, or an unknown ``method``.
        SystemExit: ``method == "umap"`` but umap-learn is not installed.
    """
    n = embeddings.shape[0]
    if n < 3:
        raise ValueError(f"Too few points to project: {n}")

    # Standardize features for stable projection
    X = embeddings.astype(np.float64)
    X = (X - X.mean(0)) / (X.std(0) + 1e-8)

    if method == "tsne":
        from sklearn.manifold import TSNE
        return TSNE(
            n_components=2, perplexity=_tsne_perplexity(n), init="pca",
            learning_rate="auto", random_state=seed,
        ).fit_transform(X)

    if method == "umap":
        try:
            import umap
        except ImportError as e:
            raise SystemExit(
                "umap-learn not installed. `pip install umap-learn` or use --methods tsne."
            ) from e
        n_neighbors = max(5, min(15, n - 1))
        return umap.UMAP(
            n_components=2, n_neighbors=n_neighbors, min_dist=0.1,
            random_state=seed,
        ).fit_transform(X)

    raise ValueError(f"Unknown method: {method}")


# -----------------------------------------------------------------------------
# Drawing
# -----------------------------------------------------------------------------
def draw_scatter(ax, coords, labels, title=None, point_size=8):
    """
    Scatter 2D ``coords`` on ``ax``, coloured by activity label.

    Only labels 1 (active) and 0 (inactive) are drawn; rows with any other
    label are left out. Inactives are drawn first so actives sit on top.
    """
    act = labels == 1
    inact = labels == 0
    ax.scatter(coords[inact, 0], coords[inact, 1], s=point_size,
               c=INACTIVE_COLOR, alpha=0.55, linewidths=0, label="Inactive")
    ax.scatter(coords[act, 0], coords[act, 1], s=point_size,
               c=ACTIVE_COLOR, alpha=0.65, linewidths=0, label="Active")
    ax.set_xticks([])
    ax.set_yticks([])
    for spine in ax.spines.values():
        spine.set_linewidth(0.8)
    if title:
        ax.set_title(title, fontsize=11)


def legend_handles():
    """Return proxy legend entries (Active, Inactive) in the scatter colours."""
    return [
        Line2D([0], [0], marker="o", color="w", markerfacecolor=ACTIVE_COLOR,
               markersize=8, label="Active"),
        Line2D([0], [0], marker="o", color="w", markerfacecolor=INACTIVE_COLOR,
               markersize=8, label="Inactive"),
    ]


# -----------------------------------------------------------------------------
# Modes
# -----------------------------------------------------------------------------
def run_single(args):
    """
    Write one PNG per (.npz file, method) pair into ``args.out_dir``.

    A projection that fails is reported and skipped so the remaining files
    are still processed.
    """
    paths = collect_npz(args.inputs)
    if not paths:
        raise SystemExit("No .npz inputs found.")
    os.makedirs(args.out_dir, exist_ok=True)

    for p in paths:
        data = load_npz(p)
        for method in args.methods:
            try:
                coords = project(data["embeddings"], method, args.seed)
            except Exception as e:
                # deliberate best-effort: one bad file must not stop the batch
                print(f"[skip] {os.path.basename(p)} ({method}): {e}")
                continue
            fig, ax = plt.subplots(figsize=(5, 5))
            title = f"{data['task']}  ·  {data['model_name']}  ·  {method.upper()}"
            draw_scatter(ax, coords, data["labels"], title=title)
            ax.legend(handles=legend_handles(), loc="best", frameon=False, fontsize=9)
            fig.tight_layout()
            out = os.path.join(
                args.out_dir,
                f"{data['task']}__{data['model_name']}__{data['split']}__{method}.png",
            )
            fig.savefig(out, dpi=200, bbox_inches="tight")
            plt.close(fig)
            print(f"Saved -> {out}")


def run_grid(args):
    """
    Write a single multi-panel PNG, one panel per task, sorted by task id.

    When ``args.model_name`` is set only files of that architecture are used.
    Panels whose projection fails show the error in their title instead.
    """
    if args.ncols < 1:
        # would otherwise surface as a ZeroDivisionError in the row computation
        raise SystemExit("--ncols must be >= 1.")

    paths = collect_npz(args.inputs)
    # filter by model_name if requested
    items = []
    for p in paths:
        d = load_npz(p)
        if args.model_name and d["model_name"] != args.model_name:
            continue
        items.append(d)
    if not items:
        raise SystemExit("No matching .npz for grid.")

    items.sort(key=lambda d: d["task"])
    n = len(items)
    ncols = args.ncols
    nrows = (n + ncols - 1) // ncols

    fig, axes = plt.subplots(nrows, ncols, figsize=(3.2 * ncols, 3.2 * nrows))
    # plt.subplots returns a bare Axes for a 1x1 grid; normalise to a flat array
    axes = np.atleast_1d(axes).ravel()

    for ax, d in zip(axes, items):
        try:
            coords = project(d["embeddings"], args.method, args.seed)
            draw_scatter(ax, coords, d["labels"], title=d["task"], point_size=6)
        except Exception as e:
            ax.set_title(f"{d['task']}\n(skip: {e})", fontsize=8)
            ax.set_xticks([])
            ax.set_yticks([])
    # hide the unused trailing panels of the last row
    for ax in axes[n:]:
        ax.axis("off")

    fig.legend(handles=legend_handles(), loc="lower right", frameon=False, fontsize=11)
    model_tag = args.model_name or "all"
    fig.suptitle(f"{args.method.upper()} embeddings — {model_tag}", fontsize=14)
    fig.tight_layout(rect=[0, 0, 1, 0.97])

    os.makedirs(args.out_dir, exist_ok=True)
    out = os.path.join(args.out_dir, f"grid__{model_tag}__{args.method}.png")
    fig.savefig(out, dpi=200, bbox_inches="tight")
    plt.close(fig)
    print(f"Saved -> {out}")


def main():
    """CLI entry point: dispatch to the ``single`` or ``grid`` sub-command."""
    ap = argparse.ArgumentParser(description="Plot t-SNE/UMAP projections.")
    sub = ap.add_subparsers(dest="cmd", required=True)

    # options shared by both sub-commands
    common = argparse.ArgumentParser(add_help=False)
    common.add_argument("--inputs", nargs="+", required=True,
                        help="Folders and/or .npz files.")
    common.add_argument("--out_dir", default=os.path.join(
        os.path.dirname(os.path.abspath(__file__)), "figures"))
    common.add_argument("--seed", type=int, default=42)

    s = sub.add_parser("single", parents=[common])
    s.add_argument("--methods", nargs="+", default=["tsne", "umap"],
                   choices=["tsne", "umap"])

    g = sub.add_parser("grid", parents=[common])
    g.add_argument("--model_name", default=None)
    g.add_argument("--method", default="tsne", choices=["tsne", "umap"])
    g.add_argument("--ncols", type=int, default=4)

    args = ap.parse_args()
    if args.cmd == "single":
        run_single(args)
    else:
        run_grid(args)


if __name__ == "__main__":
    main()
def parse_task_and_model(path):
    """
    Infer (task, model_name) from a trained-model filename.

    The task is the token between ``_best_val-`` and the next ``-`` when the
    name looks like ``<TASK>_best_val-<TASK>-...``; otherwise it is everything
    before the first ``-``. ``model_name`` is the first known architecture
    found as a substring, or None.
    """
    name = os.path.basename(path)
    model_name = next(
        (k for k in ("CNNModel1", "CNNModel2", "ViT", "YOLOv11") if k in name), None
    )
    # Canonical: "<TASK>_best_val-<TASK>-<MODEL>-..."
    m = re.match(r"(.+?)_best_val-([^-]+)-", name)
    task = m.group(2) if m else name.split("-")[0]
    return task, model_name


def main():
    """
    Extract embeddings for every ``*.pth`` under --models_dir, then plot.

    Each step runs as a subprocess with the current interpreter. A failing
    step is reported with its return code and the batch continues.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--models_dir", required=True)
    ap.add_argument("--split", default="test")
    ap.add_argument("--emb_dir", default=os.path.join(HERE, "embeddings"))
    ap.add_argument("--fig_dir", default=os.path.join(HERE, "figures"))
    ap.add_argument("--methods", nargs="+", default=["tsne", "umap"])
    ap.add_argument("--cuda", type=int, default=0)
    ap.add_argument("--skip_plots", action="store_true")
    args = ap.parse_args()

    pths = sorted(glob.glob(os.path.join(args.models_dir, "**", "*.pth"), recursive=True))
    if not pths:
        raise SystemExit(f"No .pth under {args.models_dir}")

    extractor = os.path.join(HERE, "extract_embeddings.py")
    plotter = os.path.join(HERE, "plot_projections.py")

    print(f"Found {len(pths)} model files.")
    for p in pths:
        task, model_name = parse_task_and_model(p)
        if model_name is None:
            print(f"[skip] cannot infer architecture: {os.path.basename(p)}")
            continue
        print(f"\n>>> {task} / {model_name}")
        cmd = [
            sys.executable, extractor,
            "--model_path", p, "--task", task, "--split", args.split,
            "--out_dir", args.emb_dir, "--cuda", str(args.cuda),
            "--model_name", model_name,
        ]
        rc = subprocess.call(cmd)
        if rc != 0:
            print(f"[warn] extraction failed (rc={rc}) for {p}")

    if args.skip_plots:
        return

    # Per-architecture grids
    for model_name in ("CNNModel1", "CNNModel2", "ViT", "YOLOv11"):
        for method in args.methods:
            cmd = [
                sys.executable, plotter, "grid",
                "--inputs", args.emb_dir, "--model_name", model_name,
                "--method", method, "--out_dir", args.fig_dir,
            ]
            rc = subprocess.call(cmd)
            if rc != 0:
                # also expected when emb_dir holds no file for this architecture
                print(f"[warn] grid plot failed (rc={rc}) for {model_name} / {method}")
    # Individual panels for everything
    cmd = [
        sys.executable, plotter, "single",
        "--inputs", args.emb_dir, "--methods", *args.methods,
        "--out_dir", args.fig_dir,
    ]
    rc = subprocess.call(cmd)
    if rc != 0:
        print(f"[warn] single-panel plotting failed (rc={rc})")


if __name__ == "__main__":
    main()