<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/af_pseudo_diffusion.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# AF_pseudo_diffusion + proteinMPNN
Hacking AlphaFold into a diffusion model (for backbone generation). At each step, proteinMPNN logits are blended into the sequence logits, with a weight that grows over the iterations.


**WARNING**: This notebook is experimental and designed as a control; it is not intended for practical use at this stage.

---
**NEW**
For the latest version of this notebook (reconfigured to work in distogram space), go here:
[af_pseudo_diffusion_dgram.ipynb](https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/af_pseudo_diffusion_dgram.ipynb)

In [None]:
#@title setup
%%time
import os
if not os.path.isdir("params"):
  # get code
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  # for debugging
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os, re
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model
from IPython.display import HTML
from google.colab import files
import numpy as np

def get_pdb(pdb_code=""):
  """Resolve `pdb_code` to a local PDB file, uploading or downloading it if needed.

  Args:
    pdb_code: one of
      - "" or None: prompt for a file upload (Colab) and save it as "uploaded.pdb";
      - a path to an existing file: returned unchanged;
      - a 4-character code: downloaded from RCSB as "<code>.pdb";
      - anything else: downloaded from the AlphaFold DB as
        "AF-<code>-F1-model_v3.pdb" (presumably a UniProt accession).

  Returns:
    Path of the local PDB file.

  Raises:
    ValueError: the upload prompt returned no file.
    FileNotFoundError: a download did not produce the expected file.
  """
  import shlex
  if not pdb_code:
    upload_dict = files.upload()
    if not upload_dict:
      raise ValueError("no file was uploaded")
    pdb_string = next(iter(upload_dict.values()))
    # Not "tmp.pdb": that generic name is also used for saving designs and would
    # clobber the uploaded starting structure.
    path = "uploaded.pdb"
    with open(path, "wb") as out:
      out.write(pdb_string)
    return path
  if os.path.isfile(pdb_code):
    return pdb_code
  if len(pdb_code) == 4:
    url = f"https://files.rcsb.org/view/{pdb_code}.pdb"
    path = f"{pdb_code}.pdb"
  else:
    url = f"https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb"
    path = f"AF-{pdb_code}-F1-model_v3.pdb"
  # pdb_code is free-form user text, so quote it before handing it to the shell.
  os.system(f"wget -qnc {shlex.quote(url)}")
  # wget does not save the body of an HTTP error, so a missing file means the fetch failed.
  if not os.path.isfile(path):
    raise FileNotFoundError(f"could not download {url}")
  return path

def sample_gumbel(shape, eps=1e-20):
  """Return standard Gumbel(0, 1) noise with the requested shape.

  Applies the inverse-CDF transform -log(-log(u)) to u ~ Uniform[0, 1),
  adding `eps` inside each logarithm so its argument stays positive.
  Samples come from NumPy's global random state.
  """
  uniform = np.random.uniform(size=shape)
  inner = -np.log(uniform + eps)
  return -np.log(inner + eps)

In [None]:
#@title initialize the model
length = 100 #@param {type:"integer"}
#@markdown Provide a starting point (optional)
starting_seq = "" #@param {type:"string"}
#@markdown - if `starting_seq` provided the `length` option will be overwritten by length of starting sequence.
use_starting_pdb = False #@param {type:"boolean"}
pdb = "" #@param {type:"string"}
#@markdown - specify PDB or Uniprot code or leave pdb blank for upload prompt
chains = "A" #@param {type:"string"}
#@markdown - (example `A` or `A,B` for complexes)
fix_pos = "" #@param {type:"string"}
#@markdown - specify which positions to keep fixed in the sequence (example: `1,2-10`)
#@markdown - you can also specify chain specific constraints (example: `A1-10,B1-20`)
#@markdown - you can also specify to fix entire chain(s) (example: `A`)

# Upper-case the user-entered sequence and drop every character outside A-Z.
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())
# A non-empty starting sequence overrides the `length` form value.
length = len(starting_seq) if starting_seq else length

# NOTE(review): clear_mem() presumably releases memory held by previously built models -- confirm.
clear_mem()
# "fixbb" when starting from a structure, "hallucination" otherwise; templates enabled in both.
af_model = mk_afdesign_model(protocol="fixbb" if use_starting_pdb else "hallucination",
                             use_templates=True)
if use_starting_pdb:
  af_model.prep_inputs(get_pdb(pdb), chains, fix_pos=fix_pos)
else:
  af_model.prep_inputs(length=length)
mpnn_model = mk_mpnn_model()
print("lengths", af_model._lengths)

In [None]:
#@title run protocol
#@markdown Optimization options
iterations = 100 #@param ["50", "100"] {type:"raw"}
use_xyz_noise = True #@param {type:"boolean"}
use_seq_noise = True #@param {type:"boolean"}
use_dropout = True #@param {type:"boolean"}
use_plddt = True #@param {type:"boolean"}
store_denoised = True #@param {type:"boolean"}
#@markdown - this does not change results, but toggle between storing noised vs denoised coordinates for animation.

#@markdown AlphaFold options
sample_models = False #@param {type:"boolean"}
rm_template_seq = True #@param {type:"boolean"}

#@markdown proteinMPNN options (disable to keep sequence the same)
use_mpnn = True #@param {type:"boolean"}
mpnn_mode = "conditional" #@param ["conditional", "unconditional"]

# Pseudo-diffusion loop: each iteration perturbs the inputs (sequence bias and/or
# template coordinates), runs one AlphaFold prediction to "denoise" them, feeds that
# prediction back in as the next template, and optionally blends proteinMPNN logits
# into the sequence parameters.

af_model.restart(mode="gumbel")

# Same test as the init cell, so that any non-empty starting sequence (including a
# single residue, whose length the init cell already used) is actually applied.
if len(starting_seq) > 0:
  af_model.set_seq(seq=starting_seq)
elif use_starting_pdb:
  af_model.set_seq(mode="wildtype")

af_model._inputs["rm_template_seq"] = rm_template_seq
L = sum(af_model._lengths)

af_model._inputs["bias"] = np.zeros((L,20))

if not use_starting_pdb:
  # start from an empty template: no atoms marked as present
  af_model._inputs["batch"] = {"aatype":np.zeros(L).astype(int),
                               "all_atom_mask":np.zeros((L,37)),
                               "all_atom_positions":np.zeros((L,37,3))}
# NOTE(review): with use_starting_pdb the batch is not reset here and the loop below
# overwrites it, so re-running this cell without re-running the init cell presumably
# continues from the previous run's final structure (unless restart() resets it) -- confirm.

for k in range(iterations):
  # add noise
  if use_seq_noise:
    # fresh (not accumulated) Gumbel noise on the sequence bias, fixed 0.1 scale
    af_model._inputs["bias"] = 0.1 * sample_gumbel((L,20))

  if use_xyz_noise:
    # coordinate noise decays linearly from scale 1 at k=0 towards 0 at the last step
    n = np.random.normal(size=(L,37,3)) * (1-k/iterations)
    af_model._inputs["batch"]["all_atom_positions"] += n
    if not store_denoised and k > 0:
      # store the noised frame instead: add atom-1 noise to the last saved trajectory frame
      af_model._tmp["traj"]["xyz"][-1] += n[:,1]

  # denoise: a single AlphaFold prediction on the noised inputs
  aux = af_model.predict(return_aux=True, verbose=False,
                         sample_models=sample_models, dropout=use_dropout)
  plddt = af_model.aux["plddt"]
  # update inputs: the prediction becomes the template for the next iteration
  af_model._inputs["batch"]["aatype"] = af_model.aux["seq"]["hard"].argmax(-1)[0]
  af_model._inputs["batch"]["all_atom_mask"][:,:4] = 1
  if use_plddt:
    # scale the mask of atoms 1 and 3 by sqrt(pLDDT) of the current prediction
    af_model._inputs["batch"]["all_atom_mask"][:,(1,3)] = np.sqrt(plddt)[:,None]
  af_model._inputs["batch"]["all_atom_positions"] = af_model.aux["atom_positions"].copy()

  # add logits from proteinmpnn at each stage
  if use_mpnn:
    mpnn_model.get_af_inputs(af_model)
    # "unconditional" mode scores with an all-zero ar_mask
    opt = {} if mpnn_mode == "conditional" else {"ar_mask":np.zeros((L,L))}
    mpnn_out = mpnn_model.score(**opt)
    aux["log"]["mpnn"] = mpnn_out["score"]
    mpnn_logits = mpnn_out["logits"][:,:20]
    # linear blend: the MPNN weight m grows from 0 at k=0 towards 1
    m = (k/iterations)
    af_model._params["seq"] = (1-m) * af_model._params["seq"] + m * mpnn_logits

  # save results
  af_model._save_results(aux)
  af_model._k += 1

In [None]:
# Plot the structure currently held by af_model
# (presumably the final prediction from the run protocol cell -- confirm).
af_model.plot_pdb()

In [None]:
# Write af_model's current structure to "tmp.pdb".
af_model.save_pdb("tmp.pdb")
# Last expression, so its return value (presumably the designed sequence(s)) is displayed.
af_model.get_seqs()

In [None]:
# Render af_model.animate() output as inline HTML
# (presumably an animation of the trajectory saved each iteration via _save_results -- confirm).
HTML(af_model.animate(dpi=100))