Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
28 changes: 13 additions & 15 deletions scripts/generate_esm_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -43,11 +43,15 @@

from src.constants import ONE_TO_THREE, THREE_TO_ONE
from src.dataset import parse_asu_with_biotite
from src.utils import normalize_ins_code, parse_split_file, setup_logging_for_tqdm
from src.utils import (
normalize_ins_code,
parse_split_file,
setup_logging_for_tqdm,
)


def compute_esm_embeddings(
pdb_path: Path,
struc_path: Path,
model: ESM3,
) -> dict | None:
"""
Expand All @@ -61,17 +65,17 @@ def compute_esm_embeddings(
How ESM parses: (https://github.com/evolutionaryscale/esm/blob/main/esm/utils/structure/protein_chain.py)

Args:
pdb_path: Path to PDB file
struc_path: Path to structure file (PDB/CIF)
model: Loaded ESM3 model

Returns:
Dict with 'residue_embeddings', 'sequence', 'num_residues', or None on error
"""
try:
# Load ground truth atoms using geometry cache parser in src/dataset.py
protein_atoms, _ = parse_asu_with_biotite(str(pdb_path))
protein_atoms, _ = parse_asu_with_biotite(str(struc_path))
if len(protein_atoms) == 0:
raise ValueError(f"No protein atoms found in {pdb_path}")
raise ValueError(f"No protein atoms found in {struc_path}")

# Extract ground truth sequence before mutating the array
key_to_resname = {}
Expand Down Expand Up @@ -115,7 +119,7 @@ def compute_esm_embeddings(
protein = ESMProtein.from_protein_complex(complex_obj)

if not protein.sequence or protein.sequence.replace("|", "") == "":
raise ValueError(f"ESM returned empty sequence for {pdb_path}")
raise ValueError(f"ESM returned empty sequence for {struc_path}")

with torch.no_grad():
protein_tensor = model.encode(protein)
Expand All @@ -140,7 +144,7 @@ def compute_esm_embeddings(
# Validate: length mismatch means embeddings won't align with residues
if len(esm_seq) != num_residues:
raise ValueError(
f"Length mismatch after sanitization for {pdb_path}! "
f"Length mismatch after sanitization for {struc_path}! "
f"Biotite: {num_residues}, ESM: {len(esm_seq)}"
)

Expand All @@ -151,7 +155,7 @@ def compute_esm_embeddings(
}

except Exception as e:
logger.error(f"Error computing embeddings for {pdb_path}: {e}")
logger.error(f"Error computing embeddings for {struc_path}: {e}")
return None


Expand Down Expand Up @@ -239,16 +243,10 @@ def main() -> None:
failures = []

for entry in tqdm(entries, desc="Computing ESM embeddings"):
pdb_path = entry["pdb_path"]
cache_key = entry["cache_key"]
cache_path = esm_cache_dir / f"{cache_key}.pt"

if not pdb_path.exists():
logger.error(f"PDB file not found: {pdb_path}")
failures.append((cache_key, "PDB file not found"))
continue

result = compute_esm_embeddings(pdb_path, model)
result = compute_esm_embeddings(entry["struc_path"], model)

if result is not None:
result["pdb_id"] = entry["pdb_id"]
Expand Down
8 changes: 1 addition & 7 deletions scripts/generate_slae_embeddings.py
Original file line number Diff line number Diff line change
Expand Up @@ -390,17 +390,11 @@ def main() -> None:
batch_info = []

for entry in entry_batch:
pdb_path = entry["pdb_path"]
cache_key = entry["cache_key"]

if not pdb_path.exists():
logger.error(f"PDB file not found: {pdb_path}")
failures.append((cache_key, "PDB file not found"))
continue

try:
# protein_atoms: biotite AtomArray with num_atoms atoms
protein_atoms, _ = parse_asu_with_biotite(str(pdb_path))
protein_atoms, _ = parse_asu_with_biotite(str(entry["struc_path"]))
# coords: (num_residues, 37, 3) - atom37 coordinates
# residue_type: (num_residues,) - residue type indices
# chains: (num_residues,) - chain IDs
Expand Down
108 changes: 71 additions & 37 deletions src/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
import torch
import torch.nn.functional as F
from biotite.structure.io.pdb import get_structure, PDBFile
from biotite.structure.io.pdbx import CIFFile, get_structure as get_structure_cif
from loguru import logger
from scipy.spatial.distance import cdist
from torch import Tensor
Expand All @@ -46,14 +47,28 @@ def element_onehot(symbols: list[str]) -> Tensor:
return F.one_hot(indices, num_classes=other_idx + 1).float()


def _read_structure(
    path: str | Path, extra_fields: list[str] | None = None
) -> bts.AtomArray:
    """Read a structure from a PDB or CIF file, dispatching on the extension.

    The file is parsed with ``model=1`` and ``altloc="occupancy"`` in both
    branches so PDB and CIF inputs are handled with the same settings.

    Args:
        path: Path to the structure file. A ``.cif`` or ``.mmcif`` suffix
            (compared case-insensitively) selects the CIF reader; any other
            suffix is read as PDB.
        extra_fields: Optional annotation names (e.g. ``["b_factor"]``)
            forwarded to the reader. Omitted from the call when empty or None.

    Returns:
        The structure returned by the selected biotite ``get_structure`` call.
    """
    path = Path(path)

    read_kwargs = {"model": 1, "altloc": "occupancy"}
    if extra_fields:
        read_kwargs["extra_fields"] = extra_fields

    # Lower-case the suffix so "X.CIF" is not silently routed to the PDB
    # parser; ".mmcif" is the other common spelling of the same format.
    if path.suffix.lower() in {".cif", ".mmcif"}:
        return get_structure_cif(CIFFile.read(path), **read_kwargs)

    return get_structure(PDBFile.read(path), **read_kwargs)


def parse_asu_with_biotite(
path: str,
path: str | Path,
) -> tuple[bts.AtomArray, bts.AtomArray]:
Comment thread
Copilot marked this conversation as resolved.
"""
Parse PDB file and extract protein and water atoms.
Parse PDB or CIF file and extract protein and water atoms.

Args:
path: Path to PDB file
path: Path to PDB or CIF file

Returns:
Tuple of (protein_atoms, water_atoms) as biotite AtomArrays.
Expand All @@ -64,9 +79,10 @@ def parse_asu_with_biotite(
- altloc="occupancy": Selects highest-occupancy alternate conformation
- Uses filter_amino_acids (not filter_canonical_amino_acids) to include
modified residues like MSE, SEC that external encoders may handle
- b_factor is always read so the caller can compute normalized B-factors
without a second file read.
"""
pdb_file = PDBFile.read(path)
atoms = get_structure(pdb_file, model=1, altloc="occupancy")
atoms = _read_structure(path, extra_fields=["b_factor"])

atoms = atoms[atoms.element != "H"]

Expand All @@ -80,7 +96,7 @@ def parse_asu_with_biotite(


def get_crystal_contacts_pymol(
pdb_path: str, cutoff: float = 5.0
struc_path: str, cutoff: float = 5.0
) -> dict[str, np.ndarray | list]:
"""
Extract ASU and symmetry mate atoms within crystal contact distance.
Expand All @@ -89,7 +105,7 @@ def get_crystal_contacts_pymol(
interface atoms within the specified cutoff distance.

Args:
pdb_path: Path to PDB file with crystal symmetry information
struc_path: Path to structure file (PDB/CIF) with crystal symmetry information
cutoff: Distance cutoff in Angstroms for interface detection

Returns:
Expand All @@ -104,7 +120,7 @@ def get_crystal_contacts_pymol(
cmd.reinitialize()
cmd.feedback("disable", "all", "everything")
obj = "struct"
cmd.load(pdb_path, obj)
cmd.load(struc_path, obj)
cmd.symexp("sym", obj, obj, cutoff)
cmd.select("interface", f"byres (sym* within {cutoff} of {obj})")

Expand Down Expand Up @@ -492,7 +508,7 @@ def load_edia_for_pdb(


def compute_normalized_bfactors(
pdb_path: str,
struc_path: str,
) -> tuple[dict[tuple[str, int, str], float] | None, np.ndarray | None]:
"""
Extract and normalize B-factors for water molecules.
Expand All @@ -501,7 +517,7 @@ def compute_normalized_bfactors(
in the selected structure.

Args:
pdb_path: Path to PDB file
struc_path: Path to structure file (PDB/CIF)

Returns:
Tuple of:
Expand All @@ -510,34 +526,38 @@ def compute_normalized_bfactors(
Returns (None, None) on error
"""
try:
pdb_file = PDBFile.read(pdb_path)
atoms = pdb_file.get_structure(
model=1, altloc="occupancy", extra_fields=["b_factor"]
)
atoms = _read_structure(struc_path, extra_fields=["b_factor"])

# filter for water molecules
water_mask = (atoms.res_name == "HOH") | (atoms.res_name == "WAT")
water_atoms = atoms[water_mask]

return _compute_normalized_bfactors_from_atoms(water_atoms)

except Exception as e:
logger.warning(f"Warning: Could not extract B-factors from {struc_path}: {e}")
return None, None


def _compute_normalized_bfactors_from_atoms(
water_atoms: bts.AtomArray,
) -> tuple[dict[tuple[str, int, str], float] | None, np.ndarray | None]:
"""Compute normalized B-factors from an already-parsed water AtomArray."""
try:
if not water_atoms:
return None, None

# Normalize using water-only B-factor statistics.
water_mean = np.mean(water_atoms.b_factor)
water_std = np.std(water_atoms.b_factor)

# lookup dictionary with one entry per unique water residue
bfactor_lookup = {}

for i in range(len(water_atoms)):
chain_id = str(water_atoms.chain_id[i])
res_id = int(water_atoms.res_id[i])
ins_code = normalize_ins_code(water_atoms.ins_code[i])
key = (chain_id, res_id, ins_code)

if key not in bfactor_lookup:
raw_bfactor = water_atoms.b_factor[i]
# If all water B-factors are identical, assign neutral z-score 0.0.
normalized = (
(raw_bfactor - water_mean) / max(water_std, 1e-3)
if water_std > 0
Expand All @@ -548,7 +568,7 @@ def compute_normalized_bfactors(
return bfactor_lookup, water_atoms.b_factor

except Exception as e:
logger.warning(f"Warning: Could not extract B-factors from {pdb_path}: {e}")
logger.warning(f"Warning: Could not compute B-factors from atoms: {e}")
return None, None


Expand Down Expand Up @@ -806,34 +826,47 @@ def _parse_pdb_list(self, pdb_list_file: str) -> list[dict]:
Expected format:
<pdb_id>_final (e.g., "6eey_final")

Constructs path: {base_pdb_dir}/{pdb_id}/{pdb_id}_final.pdb
Resolves path in {base_pdb_dir}/{pdb_id}/, preferring
{pdb_id}_final.cif when present, otherwise falling back to
{pdb_id}_final.pdb.
"""
entries = []
logger.info(f"Parsing PDB list: {pdb_list_file}")
pdb_ids = []
with open(pdb_list_file, "r") as f:
for line in f:
line = line.strip()
if not line:
continue

if not line.endswith("_final"):
logger.warning(f"Warning: Unexpected format: {line}")
continue
pdb_id = line.removesuffix("_final")
if not pdb_id:
logger.warning(f"Warning: Unexpected format: {line}")
continue
pdb_ids.append((pdb_id, line))

pdb_path = self.base_pdb_dir / pdb_id / f"{pdb_id}_final.pdb"
logger.info(
f"Read {len(pdb_ids)} IDs, resolving file paths for requested entries..."
)

# Cache key is just the base key - directory separation handles mates
entries.append(
{
"pdb_id": pdb_id,
"pdb_path": pdb_path,
"cache_key": line,
"embedding_key": line, # Same as cache_key for embedding lookup
}
)
for pdb_id, cache_key in pdb_ids:
subdir = self.base_pdb_dir / pdb_id
cif_path = subdir / f"{pdb_id}_final.cif"
struc_path = (
cif_path if cif_path.is_file() else subdir / f"{pdb_id}_final.pdb"
)
Comment thread
Copilot marked this conversation as resolved.

# Cache key is just the base key - directory separation handles mates
entries.append(
{
"pdb_id": pdb_id,
"struc_path": struc_path,
"cache_key": cache_key,
"embedding_key": cache_key, # Same as cache_key for embedding lookup
}
)

logger.info(f"Loaded {len(entries)} entries from {pdb_list_file}")
return entries
Expand Down Expand Up @@ -898,9 +931,9 @@ def _preprocess_one(self, entry: dict, cache_path: Path):

Raises ValueError if structure fails quality filters.
"""
pdb_path = str(entry["pdb_path"])
struc_path = str(entry["struc_path"])

protein_atoms, water_atoms = parse_asu_with_biotite(pdb_path)
protein_atoms, water_atoms = parse_asu_with_biotite(struc_path)

# check inter-chain interactions for multi-chain proteins
chain_valid, chain_reason, _ = check_chain_interactions(
Expand All @@ -910,7 +943,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
if not chain_valid:
raise ValueError(f"Quality filter failed: {chain_reason}")

crystal_data = get_crystal_contacts_pymol(pdb_path, self.cutoff)
crystal_data = get_crystal_contacts_pymol(struc_path, self.cutoff)

# Ensure consistency between biotite and PyMOL parsing.
# Both parse the same ASU, but may differ in altloc selection, hydrogen
Expand All @@ -937,7 +970,7 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
# load EDIA data only when the EDIA filter is active
edia_lookup = None
if use_edia_filter:
edia_json_path = Path(pdb_path).with_suffix(".json")
edia_json_path = Path(struc_path).with_suffix(".json")
edia_lookup = load_edia_for_pdb(edia_json_path)
if edia_lookup is None:
raise ValueError(
Expand All @@ -946,9 +979,10 @@ def _preprocess_one(self, entry: dict, cache_path: Path):
)

# compute normalized B-factors only when the B-factor filter is active
# water_atoms already has b_factor from parse_asu_with_biotite — no second read needed
bfactor_lookup = None
if use_bfactor_filter:
bfactor_lookup, _ = compute_normalized_bfactors(pdb_path)
bfactor_lookup, _ = _compute_normalized_bfactors_from_atoms(water_atoms)

# build water keys for filtering
water_keys = list(
Expand Down
Loading
Loading