<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/tr/design.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# TrDesign in JAX!
Backprop through TrRosetta for protein design.

In [None]:
#@title install
%%bash
# params/tr/.done is created only after the code install and both parameter
# archives succeed. Keying the guard on it (rather than on the params/tr
# directory, which exists as soon as mkdir runs) makes an interrupted or failed
# install retry on the next run instead of being silently skipped.
if [ ! -f params/tr/.done ]; then
  # get code (stop here on failure so the marker is never written)
  pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1 || exit 1

  # for debugging: link the installed package next to the notebook
  # (-sfn replaces an existing link instead of failing on re-run)
  ln -sfn /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign

  # download params
  mkdir -p params/tr
  wget -qnc https://files.ipd.uw.edu/krypton/TrRosetta/models.zip -P params/tr/
  wget -qnc https://files.ipd.uw.edu/krypton/TrRosetta/bkgr_models.zip -P params/tr/
  # a missing or truncated archive makes unzip fail
  unpacked=0
  unzip -qqo params/tr/models.zip -d params/tr/ \
    && unzip -qqo params/tr/bkgr_models.zip -d params/tr/ \
    && unpacked=1
  # always drop the archives so a retry downloads fresh copies
  rm -f params/tr/models.zip params/tr/bkgr_models.zip
  if [ "$unpacked" -ne 1 ]; then
    echo "failed to download/unpack TrRosetta params; re-run this cell" >&2
    exit 1
  fi
  touch params/tr/.done
fi

In [None]:
#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
from colabdesign import *
from google.colab import files
from IPython.display import HTML
import os
import numpy as np

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file, uploading or downloading it as needed.

  Args:
    pdb_code: RCSB PDB identifier (e.g. "1TEN"). If None or "", the user is
      prompted to upload a file through Colab instead.

  Returns:
    "tmp.pdb" for an uploaded file, otherwise f"{pdb_code}.pdb" in the current
    working directory.

  Raises:
    ValueError: if the upload dialog returns no file.
    urllib.error.URLError: if downloading from RCSB fails (HTTPError, a
      subclass, for e.g. an unknown PDB code).
  """
  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no PDB file was uploaded")
    # only the first uploaded file is used; it is always saved as tmp.pdb
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  pdb_filename = f"{pdb_code}.pdb"
  # like `wget -nc`: reuse a previously downloaded copy instead of refetching
  if not os.path.isfile(pdb_filename):
    import urllib.request
    # No shell is involved, so pdb_code is never interpreted as a command, and
    # HTTP/network errors raise instead of returning a path that doesn't exist.
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    with urllib.request.urlopen(url) as response:
      data = response.read()
    # write only after the full body was read, so no truncated file is cached
    with open(pdb_filename, "wb") as out:
      out.write(data)
  return pdb_filename

# Hallucination

In [None]:
# Set up a hallucination run: no structure is given to prep_inputs, only a
# design length.
# clear_mem() presumably releases memory held by previously built models
# (TODO confirm in colabdesign); it is called before rebinding `tr_model`.
clear_mem()
tr_model = mk_trdesign_model(protocol="hallucination")
# length=100: size of the protein to design (presumably in residues — confirm)
tr_model.prep_inputs(length=100)

In [None]:
# Two-stage optimisation of the hallucination model:
#   1) 50 steps with hard=False, without keeping a best design (save_best=False)
#   2) 50 steps with hard=True, keeping the best design seen (save_best=True)
# NOTE(review): hard=False/True presumably switches between a continuous and a
# discrete sequence representation, and verbose=10 presumably logs every 10
# steps — confirm in colabdesign.
# restart() presumably re-initialises the design, so re-running the cell starts over.
tr_model.restart()
tr_model.set_opt(hard=False)
tr_model.design(50, verbose=10, save_best=False)
tr_model.set_opt(hard=True)
tr_model.design(50, verbose=10, save_best=True)

In [None]:
# Report the hallucination result: loss values, the designed sequence, and a
# plot of the model's "preds" output.
print(tr_model.get_loss())
print(tr_model.get_seq())
tr_model.plot("preds")

# Fixed-backbone design (fixbb)

In [None]:
# Fixed-backbone design: rebuild `tr_model` (replacing the hallucination model
# above) with chain A of PDB 1TEN as the target. get_pdb downloads 1TEN.pdb
# from RCSB into the working directory.
clear_mem()
tr_model = mk_trdesign_model(protocol="fixbb")
tr_model.prep_inputs(get_pdb("1TEN"),chain="A")

In [None]:
# Same two-stage schedule as the hallucination run, now on the fixbb model:
# 50 steps with hard=False, then 50 with hard=True, keeping the best design.
# The sequence found here seeds AfDesign in the "combine with AfDesign" section.
tr_model.restart()
tr_model.set_opt(hard=False)
tr_model.design(50, verbose=10, save_best=False)
tr_model.set_opt(hard=True)
tr_model.design(50, verbose=10, save_best=True)

In [None]:
# Report the fixbb result: loss values, the designed sequence, and a plot of
# the model's "preds" output.
print(tr_model.get_loss())
print(tr_model.get_seq())
tr_model.plot("preds")

# Combine with AfDesign

In [None]:
%%bash
# params/af/.done is created only after the whole archive has been extracted.
# Guarding on it (instead of the params/af directory, which mkdir creates
# before the download starts) means a failed or interrupted download is
# retried on the next run rather than silently skipped.
if [ ! -f params/af/.done ]; then
  # download alphafold weights
  mkdir -p params/af/params
  # pipefail: without it the pipeline's status is tar's alone, so a curl
  # failure could go unnoticed
  set -o pipefail
  if curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-03-02.tar | tar x -C params/af/params; then
    touch params/af/.done
  else
    echo "failed to download/extract AlphaFold params; re-run this cell" >&2
    exit 1
  fi
fi

### Initialize with the TrDesign sequence

In [None]:
# Build an AfDesign fixed-backbone model on the same target as the TrDesign
# fixbb run (PDB 1TEN, chain A). Passing chain="A" explicitly, as the TrDesign
# cell does, keeps both models' targets identical; the joint optimization
# (tr_model.af_callback()) presumably relies on that.
af_model = mk_afdesign_model(protocol="fixbb", data_dir="params/af")
af_model.prep_inputs(get_pdb("1TEN"), chain="A")

In [None]:
# Seed AfDesign with the sequence from the TrDesign run (when run in order,
# `tr_model` is the fixbb model on 1TEN), then run AfDesign's 3-stage design.
# NOTE(review): (100, 100, 10) are passed positionally; presumably the
# soft/temp/hard iteration counts of design_3stage — confirm in colabdesign.
af_model.restart(seq=tr_model.get_seq())
af_model.design_3stage(100,100,10)

In [None]:
# Plot af_model's design trajectory for the seeded run (presumably per-step losses/metrics).
af_model.plot_traj()

In [None]:
# Render af_model's predicted structure for the seeded run (presumably the saved best design — confirm).
af_model.plot_pdb()

### Let's try a joint optimization

In [None]:
# Joint optimization: restart AfDesign without a seed sequence and pass the
# TrDesign model's af_callback() into the same 3-stage schedule, so TrDesign
# takes part in AF's design loop (presumably by adding its loss — confirm).
# `tr_model` must be the fixbb model prepared on 1TEN, matching af_model's target.
af_model.restart()
af_model.design_3stage(100,100,10, callback=tr_model.af_callback())

In [None]:
# Plot af_model's design trajectory for the joint AfDesign + TrDesign run.
af_model.plot_traj()

In [None]:
# Render af_model's predicted structure for the joint run.
af_model.plot_pdb()

In [None]:
# Animate the joint run; animate() presumably returns HTML markup, which HTML() renders inline.
HTML(af_model.animate())

In [None]:
# Show the designed sequences (last expression, so the cell displays the value).
af_model.get_seqs()