<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/af/examples/af2cycler.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# The AF2cycler
This notebook contains the code to run the af2cycler and to use it to improve suboptimal designed protein backbones.
Based on:

**Alphafold2 refinement improves designability of large de novo proteins**

Christopher Frank, Dominik Schiwietz, Lara Fuß, Sergey Ovchinnikov and Hendrik Dietz

We recommend running this notebook with at least an L4 or, better, an A100 GPU, as the GPU memory needed for ESMFold is quite significant.



In [None]:
#@title setup
%%time
import os

# One-time environment bootstrap. The "params" directory doubles as the marker
# that the steps below have already been run in this runtime.
if not os.path.isdir("params"):
  bootstrap_commands = (
    # get code
    "pip -q install pyppeteer nest_asyncio",
    "pip -q install git+https://github.com/sokrypton/ColabDesign.git",
    # for debugging: expose the installed package next to the notebook
    "ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign",
    # download and unpack the AlphaFold parameters
    "mkdir params",
    "apt-get install aria2 -qq",
    "aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar",
    "tar -xf alphafold_params_2022-12-06.tar -C params",
  )
  # the commands are executed strictly in the order listed above
  for shell_command in bootstrap_commands:
    os.system(shell_command)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model

from IPython.display import HTML
from google.colab import files
import numpy as np

import requests, time
# Fetch and compile the TMscore program once; skipped when a file named
# "TMscore" already exists in the working directory.
if not os.path.isfile("TMscore"):
  for build_step in (
    "wget -qnc https://zhanggroup.org/TM-score/TMscore.cpp",
    "g++ -static -O3 -ffast-math -lm -o TMscore TMscore.cpp",
  ):
    os.system(build_step)
def tmscore(x,y):
  """Compare two structure files with the local ``./TMscore`` executable.

  Args:
    x: path of the first structure file (first command-line argument).
    y: path of the second structure file (second command-line argument).

  Returns:
    A dict holding whichever of the keys ``"rms"``, ``"tms"`` and ``"gdt"``
    could be parsed from the program output (as floats). The dict is empty
    when none of the expected lines is printed.
  """
  # output line prefix -> key in the returned dict
  fields = (("RMSD", "rms"), ("TM-score", "tms"), ("GDT-TS-score", "gdt"))
  o = {}
  # the context manager closes the pipe once the output has been consumed,
  # instead of leaving it open until garbage collection
  with os.popen(f'./TMscore {x} {y}') as output:
    for line in output:
      line = line.rstrip()
      for prefix, key in fields:
        if line.startswith(prefix):
          # the value is the first whitespace-separated token after the first "="
          o[key] = float(line.split("=")[1].split()[0])
  return o

import asyncio
import nest_asyncio
from pyppeteer import launch
import base64

# Apply nest_asyncio to enable nested event loops: esmfold_api() below drives
# its coroutine with run_until_complete() on the current event loop
# (presumably one is already running inside the notebook kernel — TODO confirm).
nest_asyncio.apply()

async def fetch_blob_content(page, blob_url):
  """Read a ``blob:`` URL inside the browser page and return its raw bytes.

  The page evaluates a small JavaScript function that fetches the blob and
  resolves with a ``data:`` URL; everything after the first comma of that
  URL is base64-decoded here.

  Args:
    page: object exposing an awaitable ``evaluate(script, arg)`` method.
    blob_url: URL handed to the in-page script.

  Returns:
    The decoded content as ``bytes``.
  """
  js_reader = """
  async (blobUrl) => {
      const blob = await fetch(blobUrl).then(r => r.blob());
      return new Promise((resolve) => {
          const reader = new FileReader();
          reader.onloadend = () => resolve(reader.result);
          reader.readAsDataURL(blob);
      });
  }
  """
  data_url = await page.evaluate(js_reader, blob_url)
  # "<header>,<base64 payload>": only the payload is of interest
  _header, payload = data_url.split(',', 1)
  return base64.b64decode(payload)

async def extract_pdb_file_download_link_and_content(url):
  """Open ``url`` in a headless browser and download the linked PDB blob.

  The page is searched for ``a.btn.bg-purple`` anchors; the first one whose
  ``href`` contains ``blob:https://esmatlas.com/`` is downloaded through
  ``fetch_blob_content``.

  Args:
    url: address of the result page to load.

  Returns:
    ``(href, content)`` for the first matching anchor, where ``content`` is
    whatever ``fetch_blob_content`` returns, or
    ``("No PDB file link found.", None)`` when no anchor matches.
  """
  browser = await launch(headless=True, args=['--no-sandbox', '--disable-setuid-sandbox'])
  try:
    page = await browser.newPage()
    await page.goto(url, {'waitUntil': 'networkidle0'})
    for element in await page.querySelectorAll('a.btn.bg-purple'):
      href = await page.evaluate('(element) => element.getAttribute("href")', element)
      # anchors without an href yield None, which cannot be searched with "in"
      if href and 'blob:https://esmatlas.com/' in href:
        content = await fetch_blob_content(page, href)
        return href, content
    return "No PDB file link found.", None
  finally:
    # always release the browser, also when navigation or the download raises
    await browser.close()

def esmfold_api(sequence):
  """Fold ``sequence`` through the esmatlas.com result page and return the PDB text.

  Args:
    sequence: amino-acid sequence inserted verbatim into the request URL.

  Returns:
    The downloaded PDB content decoded as UTF-8, or the string
    ``"Failed to retrieve PDB content."`` when nothing was downloaded.
  """
  url = f'https://esmatlas.com/resources/fold/result?fasta_header=%3Eunnamed&sequence={sequence}'
  loop = asyncio.get_event_loop()
  _link, content = loop.run_until_complete(extract_pdb_file_download_link_and_content(url))
  # guard clause: no (or empty) content means the download did not succeed
  if not content:
      return "Failed to retrieve PDB content."
  return content.decode('utf-8')

import jax
import jax.numpy as jnp
from colabdesign.af.alphafold.common import residue_constants

# Folder the user uploads the input PDB files into. exist_ok makes the call
# safe on cell re-runs without a separate existence check, and makedirs also
# creates missing parent directories.
os.makedirs('/content/in/', exist_ok=True)

import py3Dmol

def visualize_pdb_overlay(pdb1_path, pdb2_path):
    """Display two PDB files overlaid in a single 800x600 py3Dmol viewer.

    Args:
        pdb1_path: path of the structure drawn as a grey cartoon (model 0).
        pdb2_path: path of the structure drawn as a red cartoon (model 1).
    """
    viewer = py3Dmol.view(width=800, height=600)

    # (path, colour) pairs in the order the models are added; each model is
    # styled through its position in that order
    overlays = ((pdb1_path, 'grey'), (pdb2_path, 'red'))
    for model_index, (pdb_path, colour) in enumerate(overlays):
        with open(pdb_path, 'r') as handle:
            pdb_text = handle.read()
        viewer.addModel(pdb_text, 'pdb')
        viewer.setStyle({'model': model_index}, {'cartoon': {'color': colour}})

    viewer.zoomTo()
    viewer.show()




# Chroma Design

This pipeline starts with the creation of draft backbones using Chroma or any other design method you want to use. We suggest you check out these two notebooks on Chroma to generate your desired proteins:
**Chroma Quickstart**
https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaDemo.ipynb

**Chroma API Tutorial**
https://colab.research.google.com/github/generatebio/chroma/blob/main/notebooks/ChromaAPI.ipynb

If you have your PDB files please upload them into the **in/** folder and proceed


In [None]:
#@title Monomer af2cycling
#@markdown The af2cycler takes in the Chroma design and returns a new pdb file with improved structure
iterations = 10 #@param {type:"integer"}
#@markdown The af2cycled model is shown in red, while the chroma model is shown in grey



import numpy as np
import warnings

warnings.simplefilter(action='ignore', category=FutureWarning)

import os, re
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model

import os, re
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model

import numpy as np

def sample_gumbel(shape, eps=1e-20):
  """Draw Gumbel(0, 1) noise of the requested shape.

  Uniform samples U are transformed as ``-log(-log(U + eps) + eps)``.

  Args:
    shape: shape of the returned array (forwarded to ``np.random.uniform``).
    eps: small constant added inside both logarithms so that neither is
      evaluated at exactly zero.

  Returns:
    numpy array of Gumbel-distributed samples.
  """
  uniform_draws = np.random.uniform(size=shape)
  inner = -np.log(uniform_draws + eps)
  return -np.log(inner + eps)


# Cell-level configuration for the af2cycler run.
clear_mem()

# keep the form value under a second name: the per-file loop below re-assigns
# `iterations` from `iters` for every input file
iters = iterations
# input PDB files are read from in_path; results go to in_path + out_path
in_path = '/content/in/'
out_path = 'out/'
if not os.path.exists('/content/in/out/'):
    os.mkdir('/content/in/out/')


# placeholder only: the string is empty, and starting_seq is overwritten from
# the input structure inside the per-file loop below
starting_seq = ""
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())


# every directory entry is listed here; the loop below filters by file name
file_list = os.listdir(in_path)

# memory is cleared again right before the two models are built
clear_mem()
mpnn_model = mk_mpnn_model()
af_model = mk_afdesign_model(protocol="fixbb",use_templates=True,use_initial_atom_pos=True,use_initial_guess=True)

# Run the af2cycler on every PDB file in in_path: repeatedly predict with AF2,
# feed the predicted coordinates and sequence back in as the next input, and
# blend ProteinMPNN logits into the sequence parameters.
for file_name in file_list:
    # match the extension explicitly; the former test on the last character
    # ("b") also accepted unrelated names
    if file_name.endswith('.pdb'):

        in_pdb = in_path + file_name
        out_pdb = in_path + out_path + 'Out_' + file_name

        af_model.prep_inputs(pdb_filename=in_pdb, chain='A')

        # start from the residue types stored in the prepared inputs
        starting_seq = af_model._inputs['batch']["aatype"]

        iterations = iters

        use_dropout = True
        num_recycles = 0

        mpnn_mode = "conditional"
        # contacts between residues at most this far apart in sequence are ignored
        cmap_seqsep = 9

        # number of highest contact values per residue averaged into `conf`
        cmap_num = 2
        #cmap_cutoff = 14
        L = sum(af_model._lengths)
        af_model.restart(mode="gumbel")
        af_model._args["clear_prev"] = False
        #af_model.set_opt(cmap_cutoff=cmap_cutoff)
        af_model.set_weights(helix=1e-8)
        # gather info about inputs
        if "offset" in af_model._inputs:
            # use the stored offset entry itself; the whole _inputs dict was
            # assigned here before, which cannot be passed to np.abs() below
            offset = af_model._inputs["offset"]
        else:
            idx = af_model._inputs["residue_index"]
            offset = idx[:, None] - idx[None, :]
        # initialize sequence
        if len(starting_seq) > 1:
            af_model.set_seq(seq=starting_seq)
        # initialize coordinates
        af_model._inputs.pop("prev", None)
        init = af_model._inputs["batch"]['all_atom_positions'].copy()

        save_best = False
        for k in range(iterations):

            # once k exceeds iterations - 10, dropout is switched off and the
            # best result is tracked by _save_results
            if k > (iterations - 10):
                use_dropout = False
                save_best = True

            # denoise
            aux = af_model.predict(return_aux=True, verbose=False,
                                   dropout=use_dropout,
                                   num_recycles=num_recycles)

            #af_model._inputs["prev"] = aux["prev"]
            #af_model._inputs["prev"]["prev_msa_first_row"] *= 0
            #af_model._inputs["prev"]["prev_pos"] *= 0

            # per-residue confidence: mean of the cmap_num largest contact
            # values among pairs separated by more than cmap_seqsep
            cmap = aux["cmap"] * (np.abs(offset) > cmap_seqsep)
            conf = np.sort(cmap)[:, -cmap_num:].mean(-1)

            plddt = aux["plddt"]
            seq = aux["seq"]["hard"][0].argmax(-1)
            xyz = aux["atom_positions"].copy()
            # update inputs
            af_model._inputs["batch"]["aatype"] = seq
            af_model._inputs["batch"]["all_atom_positions"] = xyz

            if mpnn_mode != "none":

                mpnn_model.get_af_inputs(af_model)
                opt = {"mask": np.sqrt(conf)}
                if mpnn_mode == "unconditional":
                    opt["ar_mask"] = np.zeros((L, L))
                mpnn_out = mpnn_model.score(**opt)
                # keep the first 20 logit columns
                # (presumably the 20 standard amino acids — TODO confirm)
                mpnn_logits = mpnn_out["logits"][:, :20]
                aux["log"]["mpnn"] = mpnn_out["score"]

                c = conf[:, None]

                # low-confidence positions receive mostly Gumbel noise,
                # high-confidence positions mostly the MPNN logits
                new_logits = (1 - c) * sample_gumbel(mpnn_logits.shape) + c * mpnn_logits

                # moving average: 90 % previous parameters, 10 % new logits
                af_model._params["seq"] = 0.9 * af_model._params["seq"] + 0.1 * new_logits

            # save results
            af_model._save_results(aux, save_best=save_best)
            af_model._k += 1

        af_model.save_pdb(out_pdb)
        visualize_pdb_overlay(in_pdb, out_pdb)



In [None]:
# Render the output of af_model.animate() as inline HTML in the notebook.
HTML(af_model.animate())

In [None]:
#@markdown #Redesign with solubleMPNN for ESMFold prediction
#@markdown The standard manuscript settings were 8 sequences, 0.1 sampling temperature and the removal of cysteines
import pickle
import pandas as pd
# Redesign every af2cycled backbone in in_path with ProteinMPNN and store the
# sampled sequences as "<pdb name>_sequences.pkl" in out_path.
in_path = '/content/in/out/'
out_path = 'out_sMPNN/'
# exist_ok keeps the cell re-runnable without a separate existence check
os.makedirs(out_path, exist_ok=True)

file_list = os.listdir(in_path)


num_seqs = 8 #@param ["8", "16", "32", "64"] {type:"raw"}
mpnn_sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5", "1.0"] {type:"raw"}
rm_aa = "C" #@param {type:"string"}
use_solubleMPNN = True #@param {type:"boolean"}
#@markdown - `mpnn_sampling_temp` - control diversity of sampled sequences. (higher = more diverse).
#@markdown - `rm_aa='C'` - do not use [C]ysteines.
#@markdown - `use_solubleMPNN` - use weights trained only on soluble proteins.
#@markdown

from colabdesign.shared.protein import alphabet_list as chain_list
# Honour the use_solubleMPNN form field: the weight set is chosen when the
# model is built. Before, "soluble" was hard-coded and handed to prep_inputs,
# so the checkbox had no effect.
mpnn_model = mk_mpnn_model(weights="soluble" if use_solubleMPNN else "original")


for file in file_list:
  if file[-4:] == '.pdb':

    in_file1 = in_path + file
    mpnn_model.prep_inputs(pdb_filename=in_file1,
                          chain='A',
                          rm_aa=rm_aa)
    # sampling runs in batches of 8, hence num_seqs // 8 rounds
    out = mpnn_model.sample(num=num_seqs//8,
                            batch=8,
                            temperature=mpnn_sampling_temp)
    for seq,score in zip(out["seq"],out["score"]):
      print(score,seq.split("/")[0])
    # one row per sampled sequence, sequences in the first column
    df = pd.DataFrame(out["seq"])

    # Define the output path for saving the sequences as a .pkl file
    output_pkl_file = out_path + file[:-4] + "_sequences.pkl"

    # Save the DataFrame to a .pkl file
    with open(output_pkl_file, 'wb') as f:
        pickle.dump(df, f)

In [None]:
#@markdown #Run ESMFold to test designability

#@markdown **This cell is a little bit tricky. The problem is the compatibility between the JAX and PYTORCH frameworks between ColabDesign and ESMFold and GPU memory requirements**

#@markdown Go to *Runtime* >> *Restart Session*

#@markdown then Run this cell

#@markdown Runtime 2-5 min

#@markdown This cell runs ESMFold from Hugging Face and automatically calculates the RMSD to the designed backbone.
#@markdown NOTE: GPU memory can be a big problem here. If you get memory errors, please restart the runtime and run this cell again; it should be self-contained. Additionally, after the ESMFold prediction has finished, rerun the setup cell.

import os
import pandas as pd
from Bio.PDB import PDBParser, Superimposer
import pickle
import torch
import numpy as np
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37

import py3Dmol

def visualize_pdb_overlay(pdb1_path, pdb2_path):
    """Overlay two PDB files in one 800x600 py3Dmol viewer and show it.

    Args:
        pdb1_path: file rendered as a grey cartoon (model 0).
        pdb2_path: file rendered as a red cartoon (model 1).
    """
    viewer = py3Dmol.view(width=800, height=600)

    # structures are added in this order and styled via their model index
    styled_inputs = [(pdb1_path, 'grey'), (pdb2_path, 'red')]
    for model_index, (pdb_path, colour) in enumerate(styled_inputs):
        with open(pdb_path, 'r') as handle:
            pdb_text = handle.read()
        viewer.addModel(pdb_text, 'pdb')
        viewer.setStyle({'model': model_index}, {'cartoon': {'color': colour}})

    viewer.zoomTo()
    viewer.show()





# Load the ESMFold tokenizer and model ("facebook/esmfold_v1") and move the
# model to the first CUDA device. `tokenizer`, `model` and `device` are read
# as globals by process_sequences() below.
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

device = 'cuda:0'
model = model.cuda(device)
# only the `esm` submodule is converted to half precision
# (presumably to reduce GPU memory use — TODO confirm)
model.esm = model.esm.half()
# NOTE(review): chunk size 64 on the trunk looks like a memory/speed trade-off — confirm
model.trunk.set_chunk_size(64)
torch.backends.cuda.matmul.allow_tf32 = True

def convert_outputs_to_pdb(outputs):
    """Convert a batched model output dict into PDB strings.

    Args:
        outputs: mapping of tensors; the keys "positions",
            "atom37_atom_exists", "aatype", "residue_index" and "plddt" are
            read, and "chain_index" is used when present.

    Returns:
        A list with one PDB string per batch entry.
    """
    # atom14 -> atom37 runs on the tensor outputs, before the numpy conversion
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    outputs = {key: tensor.to("cpu").numpy() for key, tensor in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = outputs["atom37_atom_exists"]
    has_chain_index = "chain_index" in outputs

    def _entry_to_pdb(i):
        """Build the PDB string of batch entry ``i``."""
        protein = OFProtein(
            aatype=outputs["aatype"][i],
            atom_positions=atom37_positions[i],
            atom_mask=atom37_mask[i],
            # residue numbers in the written file are shifted by one
            residue_index=outputs["residue_index"][i] + 1,
            b_factors=outputs["plddt"][i],
            chain_index=outputs["chain_index"][i] if has_chain_index else None,
        )
        return to_pdb(protein)

    batch_size = outputs["aatype"].shape[0]
    return [_entry_to_pdb(i) for i in range(batch_size)]

def calculate_ca_rmsd(pdb_file1, pdb_file2):
    """Return the RMSD after superimposing the CA atoms of two PDB files.

    Args:
        pdb_file1: first structure; its CA atoms are passed first to
            ``Superimposer.set_atoms``.
        pdb_file2: second structure; its CA atoms are passed second, and the
            fitted transform is applied to this structure's atoms.

    Returns:
        The ``rms`` value reported by the superimposer.
    """
    parser = PDBParser(QUIET=True)
    structure1 = parser.get_structure("Protein1", pdb_file1)
    structure2 = parser.get_structure("Protein2", pdb_file2)

    def _ca_atoms(structure):
        """Collect every atom named "CA" in iteration order."""
        return [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]

    superimposer = Superimposer()
    superimposer.set_atoms(_ca_atoms(structure1), _ca_atoms(structure2))
    superimposer.apply(structure2.get_atoms())
    return superimposer.rms

def process_sequences(seq_list, pdb_file,in_path_pdb):
    """Fold every sequence with ESMFold and keep the one closest to ``pdb_file``.

    Each sequence is tokenized, predicted with the global ``model``, written
    to ``TMP.pdb`` and compared to ``pdb_file`` with ``calculate_ca_rmsd``.
    The prediction with the lowest RMSD is saved as
    ``<pdb_id>_best_structure.pdb`` together with a pickle of its metrics.

    Args:
        seq_list: iterable of amino-acid sequences to test.
        pdb_file: path of the designed backbone used as reference.
        in_path_pdb: directory prefix (with trailing separator) under which
            the ``output/`` folder is created.

    Returns:
        ``(lowest_rmsd, best_pdb_file, data_out)`` where ``data_out`` holds
        the "plddt", "pae" and "rmsd" values of the best prediction, or
        ``(None, None, None)`` when no sequence produced a result.
    """
    lowest_rmsd = float('inf')
    lowest_rmsd_data = None
    # derive the id from the file name rather than slicing off a fixed number
    # of leading characters, so references outside "/content/in/out/" work too
    pdb_id = os.path.splitext(os.path.basename(pdb_file))[0]
    out_ss_path = in_path_pdb + "output/"
    os.makedirs(out_ss_path, exist_ok=True)

    # scratch file, overwritten for every sequence
    tmp_pdb_file = os.path.join(out_ss_path, "TMP.pdb")

    for test_protein in seq_list:
        data = {}
        tokenized_input = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)['input_ids']
        tokenized_input = tokenized_input.cuda(device)

        with torch.no_grad():
            output = model(tokenized_input)

        # the raw output is kept so the best structure can be written later
        data['out'] = output
        data["plddt"] = torch.mean(output['plddt']).item()
        data['pae'] = torch.mean(output['predicted_aligned_error']).item()

        pdb_data = convert_outputs_to_pdb(output)
        with open(tmp_pdb_file, 'w') as file:
            file.writelines(pdb_data)

        data['rmsd'] = calculate_ca_rmsd(tmp_pdb_file, pdb_file)
        print(f'Sequence: {test_protein}, plddt: {data["plddt"]}, PAE: {data["pae"]}, RMSD: {data["rmsd"]}')

        if data['rmsd'] < lowest_rmsd:
            lowest_rmsd = data['rmsd']
            lowest_rmsd_data = data

    if lowest_rmsd_data is None:
        return None, None, None

    print(f'Lowest RMSD: {lowest_rmsd}')
    best_pdb_data = convert_outputs_to_pdb(lowest_rmsd_data['out'])
    best_pdb_file = os.path.join(out_ss_path, f"{pdb_id}_best_structure.pdb")

    with open(best_pdb_file, 'w') as file:
        file.writelines(best_pdb_data)

    # the raw model output is left out of the pickled/returned metrics
    data_out = {k: v for k, v in lowest_rmsd_data.items() if k != 'out'}

    with open(os.path.join(out_ss_path, f"{pdb_id}_best_structure_data.pkl"), 'wb') as f:
        pickle.dump(data_out, f)

    return lowest_rmsd, best_pdb_file, data_out


# Test every "<pdb name>_sequences.pkl" file in in_path against its designed
# backbone "/content/in/out/<pdb name>.pdb".
in_path= '/content/out_sMPNN/'
file_list = os.listdir(in_path)

# matching the full suffix (instead of only the last character) guarantees
# that stripping it below leaves exactly the PDB base name
seq_suffix = "_sequences.pkl"

for file in file_list:
  if not file.endswith(seq_suffix):
    continue

  output_pkl_file = in_path + file

  # NOTE: pickle can execute code while loading; only open files written by
  # this notebook
  with open(output_pkl_file, 'rb') as f:
    seq = pickle.load(f)
  # first element of each row of the unpickled table is the sequence
  seq_list = [row[0] for row in np.asarray(seq)]

  pdb_file = "/content/in/out/" + file[:-len(seq_suffix)] + '.pdb'
  print(seq_list)

  lowest_rmsd, best_pdb_file, best_data = process_sequences(seq_list, pdb_file,in_path)
  if lowest_rmsd is not None:
    print(f"Lowest RMSD: {lowest_rmsd}, Best PDB file: {best_pdb_file}")
  else:
    print("No valid result found.")
  #visualize_pdb_overlay(pdb_file, best_pdb_file)


In [None]:
#@markdown #Optional: Run AF2 based designability test


#@markdown Rerun **setup cell** if you tested ESMFold prediction before!

#@markdown This cell predicts the solubleMPNN-generated sequences with AF2, using initial guess & all-atom initialisation
# Setup for the AF2-based designability test.
import re
clear_mem()

# NOTE: in_path and file_list are re-assigned further down in this cell,
# before the driver loop reads them
in_path = '/content/in/'
out_path = 'out/'
if not os.path.exists('/content/in/out/'):
    os.mkdir('/content/in/out/')


# the string is empty, so the clean-up below leaves it empty; starting_seq is
# not read again in the visible part of this cell
starting_seq = ""
starting_seq = re.sub("[^A-Z]", "", starting_seq.upper())


file_list = os.listdir(in_path)

clear_mem()

# fixed-backbone model; read as a global by process_sequences() below
af_model = mk_afdesign_model(protocol="fixbb",use_initial_atom_pos=True,use_initial_guess=True)


def process_sequences(seq_list, pdb_file,in_path_pdb):
    """Predict every sequence with AF2 and report the one closest to ``pdb_file``.

    For each sequence the global ``af_model`` is prepared on ``pdb_file``,
    run with 3 recycles, and the prediction is saved as
    ``<pdb_id>_<index>.pdb``. The model with the lowest RMSD is additionally
    copied to ``<pdb_id>_best_structure.pdb``.

    Args:
        seq_list: iterable of amino-acid sequences to test.
        pdb_file: path of the designed backbone used as reference.
        in_path_pdb: directory prefix (with trailing separator) under which
            the ``output_AF2/`` folder is created.

    Returns:
        ``(lowest_rmsd, best_pdb_file)``, or ``(None, None)`` when no
        sequence produced a result. Both cases return two values because the
        caller unpacks exactly two.
    """
    import shutil  # local import keeps this cell independent of the setup cell

    lowest_rmsd = float('inf')
    lowest_rmsd_data = None
    best_model_file = None
    # derive the id from the file name rather than slicing off a fixed number
    # of leading characters
    pdb_id = os.path.splitext(os.path.basename(pdb_file))[0]
    out_ss_path = in_path_pdb + "output_AF2/"
    os.makedirs(out_ss_path, exist_ok=True)

    # enumerate gives every prediction its own file index; the counter was
    # never advanced before, so each model overwrote "<pdb_id>_0.pdb"
    for kk, test_protein in enumerate(seq_list):
        data = {}
        af_model.prep_inputs(pdb_filename=pdb_file, chain='A')
        af_model.predict(seq=test_protein,num_recycles=3)

        data["plddt"] = af_model.aux['losses']['plddt']
        # NOTE(review): the factor 31 presumably rescales a normalised PAE — confirm
        data['pae'] = af_model.aux['losses']['pae']*31
        data['rmsd'] = af_model.aux['losses']['rmsd']
        model_file = f'{out_ss_path}{pdb_id}_{kk}.pdb'
        af_model.save_pdb(model_file)
        print(f'Sequence: {test_protein}, plddt: {data["plddt"]}, PAE: {data["pae"]}, RMSD: {data["rmsd"]}')

        if data['rmsd'] < lowest_rmsd:
            lowest_rmsd = data['rmsd']
            lowest_rmsd_data = data
            best_model_file = model_file

    if lowest_rmsd_data is None:
        return None, None

    print(f'Lowest RMSD: {lowest_rmsd}')

    best_pdb_file = os.path.join(out_ss_path, f"{pdb_id}_best_structure.pdb")
    # make the returned path exist: copy the best per-sequence model to it
    shutil.copyfile(best_model_file, best_pdb_file)

    return lowest_rmsd, best_pdb_file


# Test every "<pdb name>_sequences.pkl" file in in_path against its designed
# backbone "/content/in/out/<pdb name>.pdb" using AF2.

# imported here because neither this cell nor the setup cell imports it, and
# the notebook asks for a runtime restart before this cell is run
import pickle

in_path= '/content/out_sMPNN/'
file_list = os.listdir(in_path)

# matching the full suffix (instead of only the last character) guarantees
# that stripping it below leaves exactly the PDB base name
seq_suffix = "_sequences.pkl"

for file in file_list:
  if not file.endswith(seq_suffix):
    continue

  output_pkl_file = in_path + file

  # NOTE: pickle can execute code while loading; only open files written by
  # this notebook
  with open(output_pkl_file, 'rb') as f:
    seq = pickle.load(f)
  # first element of each row of the unpickled table is the sequence
  seq_list = [row[0] for row in np.asarray(seq)]

  pdb_file = "/content/in/out/" + file[:-len(seq_suffix)] + '.pdb'
  print(seq_list)

  lowest_rmsd, best_pdb_file = process_sequences(seq_list, pdb_file,in_path)
  if lowest_rmsd is not None:
    print(f"Lowest RMSD: {lowest_rmsd}, Best PDB file: {best_pdb_file}")
  else:
    print("No valid result found.")
  #visualize_pdb_overlay(pdb_file, best_pdb_file)


