<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/disulfide_design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# disulfide-hallucination
For a given length, generate (hallucinate) a disulfide-rich sequence that AlphaFold predicts will fold into a well-structured protein.

In [None]:
#@title setup
%%time
import os
if not os.path.isdir("params"):
  # One-time environment bootstrap: the commands below are run in order, and
  # their exit codes are not inspected. The existence of ./params is what
  # makes later runs skip this whole step.
  setup_commands = (
    # get code
    "pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1",
    # for debugging
    "ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign",
    # download params
    "mkdir params",
    "apt-get install aria2 -qq",
    "aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar",
    "tar -xf alphafold_params_2022-12-06.tar -C params",
  )
  for shell_cmd in setup_commands:
    os.system(shell_cmd)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from IPython.display import HTML
from google.colab import files
import numpy as np

def get_pdb(pdb_code=""):
  """Resolve `pdb_code` to a local PDB file name and return that name.

  - None or "": ask for an upload through `files.upload()` and write the
    first uploaded file to "tmp.pdb".
  - a path to an existing file: returned unchanged.
  - a 4-character string: fetched from files.rcsb.org with wget.
  - anything else: fetched from alphafold.ebi.ac.uk as an
    "AF-<code>-F1-model_v3.pdb" entry with wget.

  The wget exit status is not checked, so the returned name is not
  guaranteed to exist on disk if the download failed.
  """
  # No code given: fall back to an interactive upload.
  if pdb_code is None or pdb_code == "":
    uploaded = files.upload()
    first_name = list(uploaded.keys())[0]
    with open("tmp.pdb","wb") as handle:
      handle.write(uploaded[first_name])
    return "tmp.pdb"

  # Already a local file: nothing to fetch.
  if os.path.isfile(pdb_code):
    return pdb_code

  # Four characters are treated as an RCSB PDB identifier.
  if len(pdb_code) == 4:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"

  # Everything else goes to the AlphaFold DB URL.
  os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
  return f"AF-{pdb_code}-F1-model_v3.pdb"

In [None]:
import random
from jax.lax import dynamic_slice
import jax.numpy as jnp

from colabdesign.af.loss import _get_con_loss
def get_con_loss(dgram, dgram_bins, cutoff=None, binary=True,
                 num=1, seqsep=0, offset=None):
  '''Reduce a distogram to one contact-loss value per row.

  `_get_con_loss` turns the distogram into a 2-D matrix of per-pair values.
  Pairs closer in sequence than `seqsep` are discarded, each row is sorted
  ascending, and the first `num` remaining values of each row are averaged.

  offset: optional precomputed index-offset matrix; when given, it replaces
          the |row - column| separation built from the matrix shape.
  '''
  per_pair = _get_con_loss(dgram, dgram_bins, cutoff, binary)
  n_rows, n_cols = per_pair.shape

  if offset is not None:
    far_enough = jnp.abs(offset) >= seqsep
  else:
    row_idx = jnp.arange(n_rows)[:,None]
    col_idx = jnp.arange(n_cols)[None,:]
    far_enough = jnp.abs(row_idx - col_idx) >= seqsep

  # Excluded pairs become NaN before sorting; the NaN check below keeps them
  # out of both the sum and the count.
  ranked = jnp.sort(jnp.where(far_enough, per_pair, jnp.nan))
  keep = (jnp.arange(n_cols) < num) & ~jnp.isnan(ranked)

  # The small epsilon avoids dividing by zero when a row has no valid entry.
  return jnp.where(keep, ranked, 0.0).sum(-1) / (keep.sum(-1) + 1e-8)

def generate_disulfide_pattern(L, disulfide_num, min_sep=5, max_tries=100):
    """Randomly place `disulfide_num` non-overlapping cysteine pairs in a length-`L` sequence.

    Args:
        L: sequence length; positions are 0-based indices in range(L).
        disulfide_num: number of cysteine pairs to place.
        min_sep: a pair (i, j) is accepted only if abs(i - j) > min_sep.
        max_tries: random draws attempted for each pair before giving up.

    Returns:
        (disulfide_pattern, sequence_pattern, L) where `disulfide_pattern` is
        a list of (i, j) tuples and `sequence_pattern` is a string of length
        L with 'C' at every paired position and 'X' elsewhere.

    Raises:
        ValueError: if a pair cannot be placed, either because fewer than two
            free positions remain or because `max_tries` draws all violated
            `min_sep`.
    """
    disulfide_pattern = []
    positions = list(range(L))
    for n in range(disulfide_num):
        if len(positions) < 2:
            raise ValueError(
                f"cannot place disulfide {n + 1} of {disulfide_num}: "
                f"only {len(positions)} free position(s) left in length {L}"
            )
        for _ in range(max_tries):
            i, j = random.sample(positions, k=2)
            if abs(i - j) > min_sep:  # enforce the minimum loop length
                break
        else:
            # Every draw was rejected. The original check (`_ > 99`) could
            # never fire, so this case used to silently return fewer pairs.
            raise ValueError(
                f"could not find a disulfide position for pair {n + 1} of "
                f"{disulfide_num} after {max_tries} tries "
                f"(L={L}, min_sep={min_sep})"
            )
        # Each residue may take part in at most one disulfide.
        positions.remove(i)
        positions.remove(j)
        disulfide_pattern.append((i, j))

    sequence_pattern = ['X'] * L
    for pair in disulfide_pattern:
        for idx in pair:
            sequence_pattern[idx] = 'C'

    return disulfide_pattern, ''.join(sequence_pattern), L

def disulfide_loss(inputs, outputs):
  '''Loss callback that pulls the requested cysteine pairs into contact.

  inputs:  must provide inputs['opt']['disulfide_pattern'], a sequence of
           (pos1, pos2) index pairs.
  outputs: must provide outputs['distogram']['logits'] and
           outputs['distogram']['bin_edges'].
  Returns {"disulfide": loss}; an empty pattern yields a loss of 0.
  '''
  def get_disulfide_loss(dgram, dgram_bins, disulfide_pattern):
    '''
    Func: simple disulfide loss, make the contacts < 7.0/7.5A.
    # see: https://www.ncbi.nlm.nih.gov/pmc/articles/PMC7316719/
    params: disulfide_pattern: List[(pos1, pos2), (pos3, pos4)...]
    '''
    n_bins = len(dgram_bins)
    # Named `total` so it no longer shadows the enclosing function.
    total = 0.0
    for pair in disulfide_pattern:
      i,j = pair
      # Symmetrize: add the logits for (i, j) and (j, i) before scoring.
      pair_dgram = dynamic_slice(dgram, (i,j,0), (1,1,n_bins)) + dynamic_slice(dgram, (j,i,0), (1,1,n_bins))
      total += get_con_loss(pair_dgram, dgram_bins, cutoff=7.0, binary=False, num=1)
    # jnp.mean also accepts the plain float left over when the pattern is
    # empty; calling `.mean()` on that float raised AttributeError.
    return jnp.mean(total)

  # add disulfide loss here:
  dgram_logits = outputs['distogram']['logits']
  # Prepend 0 so there is one bin value per logit channel.
  # NOTE(review): assumes len(bin_edges) + 1 equals the logits' last-axis size; confirm.
  dgram_bins = jnp.append(0, outputs["distogram"]["bin_edges"])
  return {"disulfide":get_disulfide_loss(dgram_logits, dgram_bins, inputs['opt']['disulfide_pattern'])}

0. generate_cys_pattern

In [None]:
# Draw a random pattern of 3 cysteine pairs for a 35-residue sequence and
# show the pair list, the 'C'/'X' sequence template, and the length.
disulfide_pattern, sequence_pattern, L = generate_disulfide_pattern(35, 3)
print(disulfide_pattern)
print(sequence_pattern)
print(L)

1. hallucination with disulfide-pattern

In [None]:
# Build a hallucination-protocol design model with disulfide_loss attached
# as the loss callback.
# NOTE(review): clear_mem() presumably releases memory held by an earlier model; confirm in ColabDesign.
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination", loss_callback=disulfide_loss)
# Weight for the "disulfide" key returned by the loss callback.
af_model.opt["weights"]["disulfide"] = 1.0
af_model.prep_inputs(length=L)

# Show the prepared length and the active loss weights.
print("length",af_model._len)
print("weights",af_model.opt["weights"])

In [None]:
# Restart the design from the 'C'/'X' sequence template.
# NOTE(review): rm_aa='C' presumably stops the optimizer from introducing extra cysteines,
# and add_seq=True presumably keeps the template's cysteines; confirm against ColabDesign.
af_model.restart(seq=sequence_pattern, add_seq=True, rm_aa='C')
# Expose the pair list to the loss callback, which reads inputs['opt']['disulfide_pattern'].
af_model.opt['disulfide_pattern'] = disulfide_pattern
# Set the weight of the 'con' loss term to 0.5.
af_model.opt["weights"]['con'] = 0.5

In [None]:
# Run the three-stage design.
# NOTE(review): the arguments are presumably the iteration counts of the three stages; confirm.
af_model.design_3stage(50,50,10)

In [None]:
# Save the design to "<protocol>.pdb"; the PyRosetta cell further down reads
# this same file name. Then display the structure with side chains shown.
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb(show_sidechains=True)

2. fix side chains with PyRosetta

In [None]:
from pyrosetta import init, pose_from_pdb, Pose, create_score_function
from pyrosetta.rosetta.protocols.relax import FastRelax
from pyrosetta.rosetta.core.pack.task import TaskFactory
from pyrosetta.rosetta.core.kinematics import MoveMap
from pyrosetta.rosetta.core.pack.task.operation import InitializeFromCommandline
from pyrosetta.rosetta.protocols.denovo_design import DisulfidizeMover
from pyrosetta.rosetta.core.pack.task.operation import RestrictToRepacking
from pyrosetta.rosetta.core.select.residue_selector import ChainSelector

def fastrelax(pose):
    """Run PyRosetta FastRelax on `pose` in place, scored with ref2015.

    The task factory is restricted to repacking, and the move map enables
    both backbone and chi moves. Nothing is returned; the caller's `pose`
    object is the one passed to the mover.
    """
    scorefxn = create_score_function('ref2015')

    # Packer task: repack only.
    task_factory = TaskFactory()
    task_factory.push_back(RestrictToRepacking())

    # Allow backbone and side-chain torsions to move.
    move_map = MoveMap()
    move_map.set_chi(True)
    move_map.set_bb(True)

    # NOTE(review): the second argument is presumably the number of relax
    # repeats; an older comment here said 6 while the code passes 4.
    relax_mover = FastRelax(scorefxn, 4)
    relax_mover.set_task_factory(task_factory)
    relax_mover.set_movemap(move_map)
    relax_mover.apply(pose)
    
def build_stapled_pose(pose, cys_pattern):
    """Build every disulfide in `cys_pattern` on a copy of `pose`, then relax it.

    Args:
        pose: input PyRosetta pose; it is copied, not modified.
        cys_pattern: iterable of (residue1, residue2) pairs using 0-based
            indices, as produced by generate_disulfide_pattern.

    Returns:
        The relaxed copy if the number of disulfides found in it afterwards
        equals len(cys_pattern); otherwise 0.
    """
    # Limits of 999.0 are passed to the disulfidizer.
    # NOTE(review): presumably this makes it accept any geometry; confirm.
    disulfidizer = DisulfidizeMover()
    disulfidizer.set_match_rt_limit(999.0)
    disulfidizer.set_max_disulf_score(999.0)

    full_score = create_score_function('ref2015')

    # Copy ONCE, before the loop, so the disulfides accumulate on one pose.
    # Copying inside the loop discarded every bond except the last, and left
    # `stapled_pose` undefined when `cys_pattern` was empty.
    stapled_pose = Pose().assign(pose)
    for residue1, residue2 in cys_pattern:
        # Shift from the 0-based pattern to 1-based pose numbering.
        pose_res1, pose_res2 = residue1 + 1, residue2 + 1
        disulfidizer.make_disulfide(stapled_pose, pose_res1, pose_res2, False, full_score)
        # Report with the same 1-based indices that were just bonded; the
        # original passed the 0-based values to pose2pdb.
        pdbnum1 = stapled_pose.pdb_info().pose2pdb(pose_res1)
        pdbnum2 = stapled_pose.pdb_info().pose2pdb(pose_res2)
        print('# OPTIMIZE:  Around These Residues: %s %s' % (pdbnum1, pdbnum2))

    # Relax the stapled pose.
    fastrelax(stapled_pose)

    # filter: all stapled? Count the disulfides found within the ChainSelector(1) selection.
    v = ChainSelector(1).apply(stapled_pose)
    found = disulfidizer.find_current_disulfides(stapled_pose, v, v)
    n_found = len([bond for bond in found])
    print(n_found, len(cys_pattern))
    if n_found == len(cys_pattern):
        return stapled_pose
    # Kept as 0 (not None or an exception) for backward compatibility.
    return 0
    

In [None]:
# Optimize the side chains: initialize PyRosetta with its output muted, load
# the saved design, then build the disulfides and relax.
# build_stapled_pose returns 0 instead of a pose when not all pairs are found.
init('-mute all')
pose = pose_from_pdb(f"{af_model.protocol}.pdb")
stapled_pose = build_stapled_pose(pose, disulfide_pattern)