Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .actlignore
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
.pixi/
grid_search_results/
outputs/
data/
/data/
initial_dataset_40*/
checkpoints/
release_data/
Expand Down
17,255 changes: 14,104 additions & 3,151 deletions pixi.lock

Large diffs are not rendered by default.

70 changes: 63 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -15,12 +15,58 @@ analysis = [
"scikit-learn",
"seaborn"
]
boltz = ["boltz", "cuequivariance-torch", "cuequivariance-ops-torch-cu12", "rdkit>=2025.3.6"]
dev = ["pytest", "pytest-cov", "mypy", "prek", "ty", "ruff", "pytest-loguru", "python-semantic-release"]
protenix = ["protenix>=0.6.3", "einx", "triton"]

[tool.pixi.feature.boltz]
pypi-dependencies = {"boltz" = "*", "cuequivariance-torch" = "*", "cuequivariance-ops-torch-cu12" = "*", "rdkit" = ">=2025.3.6"}
platforms = ["linux-64"]

[tool.pixi.feature.boltz-osx]
pypi-dependencies = {"boltz" = "*", "cuequivariance-torch" = "*", "rdkit" = ">=2025.3.6", "joblib" = "*", "sfcalculator-torch" = ">=0.3.2"}
platforms = ["osx-arm64"]


[tool.pixi.feature.protenix]
pypi-dependencies = {"protenix" = "*", "einx" = "*", "triton" = "*"}
platforms = ["linux-64"]

[tool.pixi.feature.protpardelle]
platforms = ["linux-64", "osx-arm64"]

[tool.pixi.feature.protpardelle.dependencies]
biopython = "*"
einops = "*"
huggingface_hub = "*"
hydra-core = "*"
jupyter = "*"
numpy = ">=1.25.0,<2.0"
omegaconf = "*"
pandas = "==2.3.1"
prody = "*"
pyyaml = "*"
scipy = "*"
pytorch = ">=2.6.0,<2.8"
tqdm = "*"
typer = "*"

[tool.pixi.feature.protpardelle.pypi-dependencies]
protpardelle = {git = "https://github.com/ProteinDesignLab/protpardelle-1c.git"}
biotite = "*"
dm-tree = "*"
jaxtyping = "*"
ml_collections = "*"
modelcif = "*"
transformers = "*"
wandb = "*"

[tool.pixi.feature.protpardelle.pypi-options.dependency-overrides]
pandas = "==2.3.1"

[project]
authors = [{email = "karson.chrispens@ucsf.edu", name = "Karson Chrispens"}]
authors = [
{email = "karson.chrispens@ucsf.edu", name = "Karson Chrispens"},
{email = "marcus.collins@astera.org", name = "Marcus D. Collins"}
]
dependencies = [
"atomworks[ml]==2.1.1",
"python-dotenv",
Expand Down Expand Up @@ -59,21 +105,26 @@ CUDA_HOME = "$CONDA_PREFIX"
PYTHONNOUSERSITE = "1"

[tool.pixi.dependencies]
cuda-toolkit = ">=12,<13"
gcc_linux-64 = ">=9,<13"
gxx_linux-64 = ">=9,<13"
ninja = "*"
numpy = "<2.0"
pyarrow = "==17.0.0"

[tool.pixi.target.linux-64.dependencies]
cuda-toolkit = ">=12,<13"
gcc_linux-64 = ">=9,<13"
gxx_linux-64 = ">=9,<13"

[tool.pixi.environments]
analysis = {features = ["analysis"]}
analysis-dev = {features = ["analysis", "dev"]}
boltz = {features = ["boltz"]}
boltz-analysis = {features = ["boltz", "analysis"]}
boltz-dev = {features = ["boltz", "dev"]}
boltz-osx = {features = ["boltz-osx", "dev", "analysis"]}
protenix = {features = ["protenix"]}
protenix-dev = {features = ["protenix", "dev"]}
protpardelle = {features = ["protpardelle"]}
protpardelle-dev = {features = ["protpardelle", "dev"]}
rf3 = {features = ["rf3"]}
rf3-dev = {features = ["rf3", "dev"]}

Expand All @@ -83,6 +134,11 @@ rdkit = ">=2024.3.5,<2025.9"
[tool.pixi.feature.rf3.pypi-dependencies]
rc-foundry = {editable = true, extras = ["rf3"], git = "https://github.com/k-chrispens/foundry.git"}

[tool.pixi.pypi-options.dependency-overrides]
pandas = "==2.3.1"
gemmi = "==0.6.7"


[tool.pixi.pypi-dependencies]
sampleworks = {editable = true, path = "."}

Expand Down Expand Up @@ -117,7 +173,7 @@ cmd = "pytest -m 'not slow' {{ flags }}"

[tool.pixi.workspace]
channels = ["conda-forge"]
platforms = ["linux-64"]
platforms = ["linux-64", "osx-arm64"]

[tool.pytest.ini_options]
addopts = "-v --strict-markers"
Expand Down
8 changes: 7 additions & 1 deletion src/sampleworks/cli/guidance.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,16 +5,22 @@
import sys

from sampleworks.utils.guidance_script_arguments import GuidanceConfig
from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance


def main(argv: list[str] | None = None) -> int:
    """Run the guidance CLI: parse arguments, build the model, run guidance.

    Args:
        argv: Argument list handed to ``GuidanceConfig.from_cli``; ``None`` is
            passed through unchanged.

    Returns:
        The ``exit_code`` attribute of the object returned by ``run_guidance``.
    """
    config = GuidanceConfig.from_cli(argv)

    # These imports are kept function-local and placed after argument parsing,
    # as in the original. NOTE(review): presumably so that parsing (e.g. --help)
    # does not pay for the heavier imports -- confirm.
    from loguru import logger

    from sampleworks.utils.guidance_script_utils import get_model_and_device, run_guidance

    logger.info(f"Running guidance with config: {config}")

    # Gather the loader arguments up front. Attributes that a given config may
    # not define are read with getattr and fall back to None.
    requested_device = config.device
    checkpoint = getattr(config, "model_checkpoint", None)
    model_name = config.model
    method = getattr(config, "method", None)
    protpardelle_config_path = getattr(config, "protpardelle_config_path", None)

    device, model_wrapper = get_model_and_device(
        requested_device,
        checkpoint,
        model_name,
        method=method,
        protpardelle_config_path=protpardelle_config_path,
    )

    return run_guidance(config, config.guidance_type, model_wrapper, device).exit_code
Expand Down
4 changes: 3 additions & 1 deletion src/sampleworks/core/samplers/edm.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,8 +422,10 @@ def step(
# t_hat will be float if check_context didn't raise
# Use no_grad when gradients aren't needed to avoid memory overhead from
# gradient checkpointing holding intermediate activations
# TODO: testing the addition of eps to the signature for use with Protpardelle-1c;
# if successful, the Protocol itself needs to be modified to match.
with torch.set_grad_enabled(allow_gradients):
x_hat_0 = model_wrapper.step(noisy_state, t_hat, features=features)
x_hat_0 = model_wrapper.step(noisy_state, t_hat, eps, features=features)

reconciler = (
context.reconciler.to(torch.as_tensor(x_hat_0).device)
Expand Down
101 changes: 101 additions & 0 deletions src/sampleworks/data/cc89_epoch415.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
data:
auto_calc_sigma_data: true
chain_residx_gap: 200
dummy_fill_mode: zero
fixed_size: 512
mixing_ratios:
- 1.0
n_aatype_tokens: 21
n_examples_for_sigma_data: 500
pdb_paths:
- /scratch/users/tianyulu/augmented_ingraham_cath_bugfree
se3_data_augment: true
sigma_data: 10.3
subset:
- designable_
translation_scale: 1.0
diffusion:
sampling:
function: uniform
s_max: 80
s_min: 0.001
training:
function: lognormal
psigma_mean: -0.5
psigma_std: 1.5
model:
compute_loss_on_all_atoms: false
conditioning_style: concat
crop_conditional: true
dummy_fill_masked_atoms: false
full_mpnn_model_path: /scratch/users/tianyulu/farfalle/ProteinMPNN/vanilla_model_weights
mpnn_model:
label_smoothing: 0.1
n_channel: 128
n_layers: 3
n_neighbors: 32
noise_cond_mult: 4
use_self_conditioning: true
mpnn_model_checkpoint: ''
pretrained_modules: []
struct_model:
arch: dit
n_atoms: 37
n_channel: 256
noise_cond_mult: 4
uvit:
cat_pwd_to_conv: false
conv_skip_connection: false
dim_head: 32
n_blocks_per_layer: 2
n_filt_per_layer: []
n_heads: 8
n_layers: 10
patch_size: 1
position_embedding_max: 32
position_embedding_type: rotary
struct_model_checkpoint: ''
task: ai-allatom
train:
batch_size: 32
checkpoint_freq: 1
checkpoints: []
ckpt_path: /scratch/users/tianyulu/farfalle/out_dir/farfalle/cc89/checkpoints/epoch206_training_state.pth
clip_grad_norm: true
crop_cond:
contiguous_prob: 0.05
discontiguous_prob: 0.9
dist_threshold: 45.0
max_discontiguous_res: 24
max_span_len: 12
recenter_coords: true
sidechain_only_prob: 0.0
sidechain_prob: 0.9
terms_prob: 0.5
crop_conditional: true
decay_steps: 2000000
eval_freq: 8000000
eval_loss_t:
- 0.1
- 0.3
- 0.5
- 0.7
- 0.9
fpd_length_ranges_per_chain:
- - 50
- 256
grad_clip_val: 1.0
length_ranges_per_chain:
- - 166
- 188
lr: 0.0001
max_epochs: 10000
n_eval_samples: 10
n_fpd_samples: 0
sc_num_seqs: 4
seed: 0
self_cond_train_prob: 0.9
shapes_path: /scratch/users/tianyulu/protein_shapes
subsample_eval_set: 0.05
warmup_steps: 1000
weight_decay: 0.0
1 change: 1 addition & 0 deletions src/sampleworks/models/protpardelle/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
"""Protpardelle-1c model wrapper."""
Loading