<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/mpnn/examples/afdesign_and_proteinmpnn_bias.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# ProteinMPNN + AFDesign

Design a protein sequence for a given structure by backpropagating through AlphaFold (AFDesign, `fixbb` protocol), with the search biased toward ProteinMPNN's predictions. Notebook from [@sim0nsays](https://github.com/sim0nsays)!

In [None]:
#@title install
%%bash
# One-time setup: install ColabDesign and download the AlphaFold parameters.
# The presence of params/ marks setup as done, so params/ must only appear
# once everything has succeeded.
if [ ! -d params ]; then
  # abort on the first failing command, including a failing curl inside the pipe
  set -e -o pipefail
  # get code
  pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1
  # for debugging; -fn makes re-running after a partial setup safe
  ln -sfn /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign
  # download params into a staging dir and rename it only after download and
  # extraction both succeed, so an interrupted run is retried next time
  # instead of being skipped by the guard above
  mkdir -p params.tmp
  curl -fsSL https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar | tar x -C params.tmp
  mv params.tmp params
fi

In [None]:
#@title Setup

#@title import libraries
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)
import os
os.environ['XLA_PYTHON_CLIENT_PREALLOCATE'] = 'false'

from colabdesign.af import mk_af_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.mpnn import mk_mpnn_model
from colabdesign.shared.utils import copy_dict
from IPython.display import HTML
from google.colab import files

import numpy as np
from scipy.special import softmax
import matplotlib.pyplot as plt
import jax
import jax.numpy as jnp

def get_pdb(pdb_code=""):
  """Return a local path to the input PDB file, downloading or uploading it first.

  Args:
    pdb_code: RCSB PDB identifier (e.g. "1TEN"). Surrounding whitespace is
      ignored. If None or blank, the user is prompted to upload a file through
      the Colab `files.upload()` dialog instead.

  Returns:
    "tmp.pdb" for an uploaded file, otherwise "<pdb_code>.pdb" in the current
    directory. An existing file with that name is reused without downloading
    again (the same no-clobber behavior as `wget -nc`).

  Raises:
    ValueError: if the upload dialog returns no file, or pdb_code contains
      characters other than ASCII letters, digits and underscores.
    RuntimeError: if downloading the entry from RCSB fails.
  """
  import re
  import urllib.request

  pdb_code = (pdb_code or "").strip()
  if not pdb_code:
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no PDB file was uploaded")
    # only the first uploaded file is used
    pdb_bytes = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_bytes)
    return "tmp.pdb"

  # pdb_code becomes part of a file name and a URL; reject anything that could
  # escape the working directory or alter the request
  if not re.fullmatch(r"[A-Za-z0-9_]+", pdb_code):
    raise ValueError(f"invalid PDB code: {pdb_code!r}")

  pdb_path = f"{pdb_code}.pdb"
  if not os.path.exists(pdb_path):
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    try:
      with urllib.request.urlopen(url, timeout=60) as response:
        pdb_bytes = response.read()
    except OSError as err:  # URLError/HTTPError/timeouts are all OSError
      raise RuntimeError(f"could not download {url}: {err}") from err
    # write only after the full download succeeded, so a failed request never
    # leaves a truncated file that later calls would silently reuse
    with open(pdb_path, "wb") as out:
      out.write(pdb_bytes)
  return pdb_path


#@markdown ### Input Options
# Colab form: choose the input structure and the design options.
# NOTE(review): get_pdb runs as soon as this form cell executes (it downloads
# the entry or opens the upload prompt).
pdb='1TEN' #@param {type:"string"}
pdb_path = get_pdb(pdb)

#@markdown - pdb code (leave blank to get an upload prompt)

designed_chain = "A" #@param {type:"string"}
fixed_chains = "" #@param {type:"string"}

#@markdown - specify which chain(s) to design and which chain(s) to keep fixed.
#@markdown   Use commas (e.g. `A,B`) to specify more than one chain
# NOTE(review): downstream, only the first chain of the prepped structure is
# treated as designed, so a comma-separated designed_chain appears to design
# only its first chain — confirm before relying on multi-chain design.

#@markdown ### Design Options
use_templates = False #@param {type:"boolean"}
#@markdown - provides structure templates to AFDesign, helpful to maintain multi-chain contacts

#@markdown
# Divides the MPNN logits when building the AFDesign bias; must be > 0.
mpnn_bias_temp = 4.0 #@param {type:"number"}
#@markdown - specifies the temperature for applying bias towards MPNN prediction.
#@markdown   *low (<= 0.1)* - explore less, stick to MPNN probs, *high(>=4)* - explore more to get better AF metric


def prepare_af_bias_start_seq(unconditional_probs, af_model, mpnn_bias_temp,
                              num_designed_chains=1):
    """Build the AFDesign sequence bias and starting sequence from ProteinMPNN output.

    Designed positions are biased toward the MPNN scores divided by
    ``mpnn_bias_temp``. All other positions get a bias of 1e8 on their
    wild-type residue, which effectively keeps those chains fixed.

    Args:
      unconditional_probs: (L, 20) per-position ProteinMPNN scores, one row per
        residue of the prepped structure. Despite the name, the notebook passes
        unconditional logits (``get_unconditional_logits()[:, :20]``).
      af_model: prepped AFDesign model. Reads ``_lengths`` (per-chain lengths,
        assumed to list the designed chain(s) first) and ``_wt_aatype``
        (wild-type residue indices into the 20 columns).
      mpnn_bias_temp: temperature dividing the MPNN scores; must be > 0.
        Smaller values give a stronger bias toward the MPNN scores.
      num_designed_chains: number of leading chains in ``af_model._lengths``
        that are designed. Defaults to 1, the original behavior.

    Returns:
      (bias, max_seq): ``bias`` is a float32 (L, 20) array and ``max_seq`` is
      the per-position argmax of the MPNN scores, used as the start sequence.

    Raises:
      ValueError: if ``mpnn_bias_temp`` is not positive (dividing by zero would
        silently produce inf/NaN bias) or ``num_designed_chains`` < 1.
    """
    if not mpnn_bias_temp > 0:  # also rejects NaN
        raise ValueError(f"mpnn_bias_temp must be > 0, got {mpnn_bias_temp}")
    if num_designed_chains < 1:
        raise ValueError(f"num_designed_chains must be >= 1, got {num_designed_chains}")

    scores = np.asarray(unconditional_probs)
    max_seq = np.argmax(scores, axis=1)

    wt_aatype = np.asarray(af_model._wt_aatype)
    designed_len = int(sum(af_model._lengths[:num_designed_chains]))

    bias = np.zeros((len(wt_aatype), 20), dtype=np.float32)
    # pin every position to wild type, then overwrite the designed span
    bias[np.arange(len(bias)), wt_aatype] = 1e8
    bias[:designed_len] = scores[:designed_len] / mpnn_bias_temp
    return bias, max_seq


def af2mpnn(self):
  """Package the backbone of an AF model's prepped structure as ProteinMPNN inputs.

  Reads ``self._inputs`` (the batch coordinates and masks, residue_index and
  asym_id) and draws one key from ``self.key()``.

  Returns:
    dict with keys "X" (N/CA/C/O coordinates), "mask" (atom column 1 of
    all_atom_mask, presumably CA), "residue_idx", "chain_idx" and "key".
  """
  backbone_atoms = ["N", "CA", "C", "O"]
  atom_idx = tuple(map(residue_constants.atom_order.__getitem__, backbone_atoms))
  batch = self._inputs["batch"]
  return {
      "X": batch["all_atom_positions"][:, atom_idx],
      "mask": batch["all_atom_mask"][:, 1],
      "residue_idx": self._inputs["residue_index"],
      "chain_idx": self._inputs["asym_id"],
      "key": self.key(),
  }

In [None]:
# Fresh ProteinMPNN and AFDesign (fixed-backbone) models.
clear_mem()
mpnn_model = mk_mpnn_model()

# Metric used to pick the best design (reported as best['metric'] after design).
best_metric = "dgram_cce" if use_templates else "rmsd"
af_design_model = mk_af_model(protocol="fixbb", best_metric=best_metric, use_templates=use_templates)

def _split_chains(spec):
    """Split a comma-separated chain spec like "B, C," into ["B", "C"]."""
    return [c.strip() for c in spec.split(",") if c.strip()]

# Designed chain(s) first, then fixed chains. Stripping whitespace and dropping
# empty entries keeps form input like "B, C" or "B," from producing invalid IDs.
chain_list = _split_chains(designed_chain)
if not chain_list:
    raise ValueError("designed_chain must name at least one chain")
chain_list += _split_chains(fixed_chains)
af_design_model.prep_inputs(pdb_filename=pdb_path, chain=",".join(chain_list))

In [None]:
# Unconditional ProteinMPNN logits for the prepped structure, keeping only the
# first 20 columns (the amino-acid alphabet used for the bias).
mpnn_model.get_af_inputs(af_design_model)
seq_logits = mpnn_model.get_unconditional_logits()[:, :20]

# Bias AFDesign toward the MPNN logits and start from their argmax sequence.
# NOTE(review): only the first chain (af_design_model._lengths[0]) is treated as
# designed, both by the bias helper and in the seqid below; later chains of a
# comma-separated designed_chain appear to stay pinned to wild type — confirm.
bias, start_seq = prepare_af_bias_start_seq(seq_logits, af_design_model, mpnn_bias_temp)
af_design_model.restart(seed=0)
af_design_model.set_seq(seq=start_seq, bias=bias)
if use_templates:
    af_design_model.set_opt("template", dropout=0.15)
af_design_model.set_weights(pae=0.01, plddt=0.01)
af_design_model.design_3stage()

# Report the best design: its tracked metric and the designed chain's
# sequence identity to the wild-type sequence.
designed_chain_len = af_design_model._lengths[0]
best = af_design_model._tmp['best']
seqid = np.mean(
    best['aux']['aatype'][:designed_chain_len]
    == af_design_model._wt_aatype[:designed_chain_len]
)
print(f"{best_metric}: {best['metric']:.3f}, designed chain seqid: {seqid:.3f}")

In [None]:
# Plot the trajectory of the design run above (plotting is done by colabdesign's plot_traj).
af_design_model.plot_traj()

In [None]:
# Save the design as "<protocol>.pdb" (fixbb.pdb here) and render it inline.
# Which model is written (best vs. last) is decided by colabdesign's save_pdb.
af_design_model.save_pdb("{}.pdb".format(af_design_model.protocol))
af_design_model.plot_pdb()

In [None]:
# Show colabdesign's animation of the design run inline; assumes animate()
# returns HTML markup, which IPython's HTML wrapper renders as the cell output.
HTML(af_design_model.animate())