<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/af_pseudo_diffusion_recycle.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AF_pseudo_diffusion + proteinMPNN
Hacking AlphaFold into a diffusion model (for backbone generation) via its recycling mechanism. At each step, logits from proteinMPNN are mixed into the sequence.


**WARNING**: This notebook is experimental and designed as a control; it is not intended for practical use at this stage.

In [None]:
#@title setup
%%time
# One-time environment setup. The existence of ./params is the "already set
# up" marker, so re-running this cell in the same runtime skips the installs.
# NOTE(review): os.system return codes are ignored, so a failed install or
# download does not stop the cell; check its output if later cells fail.
import os
if not os.path.isdir("params"):
  # get code (ColabDesign pinned to v1.1.1)
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  # for debugging: symlink the installed colabdesign package into the working
  # directory so its source is easy to inspect
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download params: the 2022-12-06 AlphaFold parameter tarball, unpacked into ./params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  # aria2c -x 16 allows up to 16 connections for the large download
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

# hide FutureWarning messages
import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os, re  # os is imported again (harmless); re cleans starting_seq later
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
# NOTE(review): residue_constants, _np_get_cb, jnp, jax, softmax and
# log_softmax are not used by any cell shown here
from colabdesign.af.alphafold.common import residue_constants
from colabdesign.shared.protein import _np_get_cb

from IPython.display import HTML
from google.colab import files  # Colab upload dialog, used by get_pdb
import numpy as np
import jax.numpy as jnp
import jax
from scipy.special import softmax, log_softmax

import tqdm.notebook
# progress-bar layout for the tqdm bar in the sampling cell
TQDM_BAR_FORMAT = '{l_bar}{bar}| {n_fmt}/{total_fmt} [elapsed: {elapsed} remaining: {remaining}]'

def get_pdb(pdb_code=""):
  """Return the path of a local PDB file, obtaining it first if necessary.

  Resolution order:
    * "" or None    -> open the Colab upload dialog (``files.upload()``) and
                       write the first uploaded file to ``tmp.pdb``;
    * existing path -> returned unchanged;
    * 4 characters  -> treated as a PDB ID and fetched from RCSB;
    * anything else -> treated as an AlphaFold DB id (presumably a UniProt
                       accession) and fetched as ``AF-<id>-F1-model_v3.pdb``.

  Args:
    pdb_code: PDB ID, AlphaFold DB id, local file path, or "" / None.

  Returns:
    Path (str) to the PDB file.

  Raises:
    ValueError: the upload dialog returned no files.
    FileNotFoundError: the download did not produce the expected file.
  """
  import shlex  # only needed by the download branches

  def _download(url, filename):
    # wget -nc skips the download when `filename` already exists; the URL is
    # shell-quoted so characters in pdb_code cannot be interpreted by the shell
    os.system(f"wget -qnc {shlex.quote(url)}")
    if not os.path.isfile(filename):
      raise FileNotFoundError(f"could not download {url}")
    return filename

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    pdb_string = next(iter(upload_dict.values()))
    with open("tmp.pdb", "wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"
  elif os.path.isfile(pdb_code):
    return pdb_code
  elif len(pdb_code) == 4:
    return _download(f"https://files.rcsb.org/view/{pdb_code}.pdb",
                     f"{pdb_code}.pdb")
  else:
    return _download(f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb",
                     f"AF-{pdb_code}-F1-model_v3.pdb")

def sample_gumbel(shape, eps=1e-20):
  """Draw i.i.d. standard Gumbel(0, 1) noise of the given shape.

  Inverse-CDF sampling, G = -log(-log(U)), with U ~ Uniform[0, 1) drawn from
  numpy's global RNG. ``eps`` keeps both logarithms away from log(0).
  """
  uniform = np.random.uniform(size=shape)
  inner = -np.log(uniform + eps)
  return -np.log(inner + eps)

In [None]:
#@title initialize the model
length = 100 #@param {type:"integer"}
#@markdown Provide a starting point (optional)
starting_seq = "" #@param {type:"string"}
# normalize: uppercase, then drop every character that is not A-Z
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())
#@markdown - if `starting_seq` provided the `length` option will be overwritten by length of starting sequence.

if len(starting_seq) > 0:
  length = len(starting_seq)

# clear_mem (colabdesign) presumably frees memory from previously built models;
# then build the hallucination model + MPNN used by the pseudo-diffusion loop
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination")
af_model.prep_inputs(length=length)
mpnn_model = mk_mpnn_model()

# separate fixbb AlphaFold + MPNN pair, used later to redesign and rescore the
# generated backbone; best_metric="rmsd" presumably picks the "best" design by RMSD
af_model_test = mk_afdesign_model(protocol="fixbb", best_metric="rmsd")
mpnn_model_test = mk_mpnn_model()

print("lengths",af_model._lengths)

In [None]:
#@title run protocol
#@markdown Optimization options
iterations = 100 #@param ["50", "100", "200"] {type:"raw"}

#AlphaFold options
use_dropout = True
num_recycles = 0

#@markdown proteinMPNN options (set to `none` to disable)
mpnn_mode = "conditional" #@param ["none","conditional", "unconditional"]

#@markdown proteinMPNN contact map masking options
cmap_seqsep = 6 #@param {type:"raw"}
cmap_num = 1 #@param {type:"integer"}
cmap_cutoff = 8 #@param {type:"raw"}

# Pseudo-diffusion loop: each AlphaFold prediction is fed back as the next
# step's recycled "prev" state and as the structure proteinMPNN scores; the
# MPNN logits are then blended into the sequence parameters.
L = sum(af_model._lengths)
af_model.restart(mode="gumbel")
# presumably stops predict() from discarding _inputs["prev"], which the loop
# below sets by hand; confirm in colabdesign
af_model._args["clear_prev"] = False
af_model.set_opt(cmap_cutoff=cmap_cutoff)
af_model.set_weights(helix=1e-8)

# gather info about inputs: sequence separation between every pair of
# positions, used to ignore near-diagonal contacts when scoring confidence.
# Assumes _inputs["offset"], when present, is the same (L, L) pairwise offset
# the fallback computes; confirm.
if "offset" in af_model._inputs:
  offset = af_model._inputs["offset"]
else:
  idx = af_model._inputs["residue_index"]
  offset = idx[:,None] - idx[None,:]

# initialize sequence (same non-empty test the init cell uses to set `length`)
if len(starting_seq) > 0:
  af_model.set_seq(seq=starting_seq)

# initialize coordinates: drop any recycled state and start from an all-zero,
# fully masked structure
af_model._inputs.pop("prev",None)
af_model._inputs["batch"] = {"aatype":np.zeros(L).astype(int),
                             "all_atom_mask":np.zeros((L,37)),
                             "all_atom_positions":np.zeros((L,37,3))}

save_best = False
for k in range(iterations):

  # final iterations (k > iterations - 10, i.e. the last 9): predict without
  # dropout and start tracking the best design
  if k > (iterations - 10):
    use_dropout = False
    save_best = True

  # denoise: predict a structure from the current sequence + recycled state
  aux = af_model.predict(return_aux=True, verbose=False,
                         dropout=use_dropout,
                         num_recycles=num_recycles)
  # recycle the prediction, but zero its MSA-first-row and position components
  af_model._inputs["prev"] = aux["prev"]
  af_model._inputs["prev"]["prev_msa_first_row"] *= 0
  af_model._inputs["prev"]["prev_pos"] *= 0

  # per position confidence: mean of the `cmap_num` largest contact-map values
  # to partners more than `cmap_seqsep` residues away.
  # NOTE: cmap_num must be >= 1; with 0 the slice `-0:` selects every column.
  cmap = aux["cmap"] * (np.abs(offset) > cmap_seqsep)
  conf = np.sort(cmap)[:,-cmap_num:].mean(-1)

  # gather features
  seq = aux["seq"]["hard"][0].argmax(-1)
  xyz = aux["atom_positions"].copy()

  # update the batch with the predicted sequence and coordinates
  af_model._inputs["batch"]["aatype"] = seq
  af_model._inputs["batch"]["all_atom_positions"] = xyz

  # add logits from proteinmpnn at each stage
  if mpnn_mode != "none":
    mpnn_model.get_af_inputs(af_model)
    # per-position MPNN mask weighted by sqrt(confidence)
    opt = {"mask":np.sqrt(conf)}
    if mpnn_mode == "unconditional":
      # all-zero autoregressive mask (unconditional decoding)
      opt["ar_mask"] = np.zeros((L,L))
    mpnn_out = mpnn_model.score(**opt)
    mpnn_logits = mpnn_out["logits"][:,:20]
    aux["log"]["mpnn"] = mpnn_out["score"]

    # interpolate per position between Gumbel noise and the MPNN logits using
    # conf as the weight, then fold the result into the sequence parameters as
    # an exponential moving average (90% old, 10% new)
    c = conf[:,None]
    new_logits = (1 - c) * sample_gumbel(mpnn_logits.shape) + c * mpnn_logits
    af_model._params["seq"] = 0.9 * af_model._params["seq"] + 0.1 * new_logits

  # save results
  af_model._save_results(aux, save_best=save_best)
  af_model._k += 1

af_model.save_pdb("init.pdb")

In [None]:
# plot the structure from the diffusion run (presumably the saved best design)
# and display its sequence(s); get_seqs() is the last expression, so its value
# is shown as the cell output
af_model.plot_pdb()
af_model.get_seqs()

In [None]:
# render af_model's animation (presumably of the iterations saved by _save_results) inline as HTML
HTML(af_model.animate(dpi=100))

In [None]:
#@title sample new sequences using proteinMPNN and rescore with alphafold (w/o template)
#@markdown #### MPNN Options
num_seqs = 16 #@param ["8", "16", "32", "64", "128", "256", "512", "1024"] {type:"raw"}
sampling_temp = 0.1 
#@markdown #### AlphaFold Options
alphafold_model = "model_4_ptm" #@param ["model_1_ptm", "model_2_ptm", "model_3_ptm", "model_4_ptm", "model_5_ptm"]
num_recycles = 3 #@param ["0", "1", "2", "3"] {type:"raw"}
import pandas as pd

# Redesign the backbone saved by the diffusion run (init.pdb) with proteinMPNN,
# then predict every designed sequence with one AlphaFold model under the
# fixbb protocol, logging confidence metrics and RMSD for each design.
af_model_test.prep_inputs("init.pdb")
mpnn_model_test.get_af_inputs(af_model_test)
# MPNN samples in batches of 8; round the batch count up so at least num_seqs
# sequences exist even when num_seqs is not a multiple of 8
out = mpnn_model_test.sample(num=(num_seqs + 7)//8, batch=8,
                             temperature=sampling_temp)
af_terms = ["plddt","ptm","pae","rmsd","dgram_cce"]
for k in af_terms: out[k] = []
os.makedirs("output/all_pdb", exist_ok=True)

# progress-bar total matches the number of designs actually processed
with tqdm.notebook.tqdm(total=num_seqs, bar_format=TQDM_BAR_FORMAT) as pbar:
  with open("design.fasta","w") as fasta:
    for n in range(num_seqs):
      seq = out["seq"][n]
      af_model_test.predict(seq=seq,
                            num_recycles=num_recycles,
                            num_models=1,
                            verbose=False,
                            models=alphafold_model)

      for t in af_terms:
        out[t].append(af_model_test.aux["log"][t])
      # the logged pae appears to be normalized; x31 rescales it
      # (TODO confirm the constant against colabdesign's pae logging)
      out["pae"][-1] = out["pae"][-1] * 31
      af_model_test._save_results(save_best=True, verbose=False)
      af_model_test.save_current_pdb(f"output/all_pdb/n{n}.pdb")
      af_model_test._k += 1

      line = f'>mpnn:{out["score"][n]:.3f}_plddt:{out["plddt"][n]:.3f}_ptm:{out["ptm"][n]:.3f}_pae:{out["pae"][n]:.3f}\n{seq}'
      fasta.write(line+"\n")
      pbar.update(1)

af_model_test.save_pdb("final.pdb")

# one row per design; the MPNN "score" column is reported as "mpnn"
labels = ["score"] + af_terms + ["seq"]
data = [[out[k][n] for k in labels] for n in range(num_seqs)]
labels[0] = "mpnn"

df = pd.DataFrame(data, columns=labels)
df.to_csv('output/mpnn_results.csv')
df.round(3).sort_values("rmsd")

In [None]:
# plot the structure from the AlphaFold rescoring (presumably the best design
# kept by _save_results(save_best=True)) and display its sequence(s) as the
# cell output
af_model_test.plot_pdb()
af_model_test.get_seqs()