<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/af/examples/RSO.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>




# Protein Design using Relaxed Sequence Optimization


**Scalable protein design using optimization in a relaxed sequence space**




Christopher Frank, Ali Khoshouei, Lara Fuß, Lara Weber, Dominik Schiewitz, Zhixuan Zhao, Motoyuki Hattori, Yosta de Stigter, Shihao Feng, Sergey Ovchinnikov and Hendrik Dietz


This notebook contains code to run relaxed sequence optimisation for de novo protein design as described in the manuscript. There are additional options to modify the pipeline according to one's needs.

We recommend using at least an L4 GPU to run this notebook, as the free T4 GPU struggles with larger proteins

Alternatively, a local installation of ColabDesign is strongly recommended, especially for the design of larger proteins.

For questions feel free to reach out to the authors


In [None]:
#@title setup
%%time
import os
# One-time environment bootstrap. The presence of the "params" directory is
# used as the marker that these steps have already been run.
if not os.path.isdir("params"):
  bootstrap_commands = (
    # get code
    "pip -q install pyppeteer nest_asyncio",
    "pip -q install git+https://github.com/sokrypton/ColabDesign.git",
    # for debugging
    "ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign",
    # download params
    "mkdir params",
    "apt-get install aria2 -qq",
    "aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar",
    "tar -xf alphafold_params_2022-12-06.tar -C params",
  )
  # Run the steps strictly in the listed order; exit codes are not inspected.
  for shell_command in bootstrap_commands:
    os.system(shell_command)

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.mpnn import mk_mpnn_model

from IPython.display import HTML
from google.colab import files
import numpy as np

import requests, time
# Fetch and compile the TM-score tool once; an existing "TMscore" file in the
# working directory means the build is skipped.
if not os.path.isfile("TMscore"):
  for build_step in (
    "wget -qnc https://zhanggroup.org/TM-score/TMscore.cpp",
    "g++ -static -O3 -ffast-math -lm -o TMscore TMscore.cpp",
  ):
    os.system(build_step)
def tmscore(x,y):
  """Run the local ``./TMscore`` executable on two files and parse its report.

  Args:
    x: first path, interpolated as the first command-line argument.
    y: second path, interpolated as the second command-line argument.

  Returns:
    dict that may contain "rms", "tms" and "gdt" (floats), taken from the
    output lines starting with "RMSD", "TM-score" and "GDT-TS-score".
    Keys whose line is absent are omitted, so the dict is empty when the
    tool produces no recognisable output.
  """
  # output line prefix -> key in the returned dict
  fields = (("RMSD", "rms"), ("TM-score", "tms"), ("GDT-TS-score", "gdt"))

  def parse_float(text):
    # value is the first whitespace-separated token after the first "="
    return float(text.split("=")[1].split()[0])

  o = {}
  # NOTE: x and y are placed on a shell command line unquoted, so paths with
  # spaces or shell metacharacters are not supported.
  # The context manager closes the pipe (and reaps the child process) even
  # if parsing raises; the original left the pipe open.
  with os.popen(f'./TMscore {x} {y}') as output:
    for line in output:
      line = line.rstrip()
      for prefix, key in fields:
        if line.startswith(prefix):
          o[key] = parse_float(line)
  return o

import asyncio
import nest_asyncio
from pyppeteer import launch
import base64

# Apply nest_asyncio to enable nested event loops
nest_asyncio.apply()

async def fetch_blob_content(page, blob_url):
  """Read a blob URL from inside the browser page and return its bytes.

  The JavaScript snippet fetches the blob and converts it to a data URL;
  the base64 payload after the first comma is decoded on the Python side.

  Args:
    page: object exposing an awaitable ``evaluate(js, arg)`` method.
    blob_url: URL handed to the in-page ``fetch``.

  Returns:
    bytes decoded from the base64 part of the data URL.
  """
  js_reader = """
  async (blobUrl) => {
      const blob = await fetch(blobUrl).then(r => r.blob());
      return new Promise((resolve) => {
          const reader = new FileReader();
          reader.onloadend = () => resolve(reader.result);
          reader.readAsDataURL(blob);
      });
  }
  """
  data_url = await page.evaluate(js_reader, blob_url)
  # the text before the first comma is the header, the rest is the payload
  _header, payload = data_url.split(',', 1)
  return base64.b64decode(payload)

async def extract_pdb_file_download_link_and_content(url):
  """Open ``url`` in a headless browser and fetch the first ESM Atlas blob link.

  Scans the ``a.btn.bg-purple`` anchors of the loaded page and downloads the
  content behind the first one whose ``href`` contains
  ``blob:https://esmatlas.com/``.

  Args:
    url: page address to load.

  Returns:
    ``(href, content)`` with ``content`` as bytes when a matching link is
    found, otherwise ``("No PDB file link found.", None)``.
  """
  browser = await launch(headless=True, args=['--no-sandbox', '--disable-setuid-sandbox'])
  # try/finally guarantees the browser process is shut down even when page
  # loading or script evaluation raises (previously it was leaked).
  try:
    page = await browser.newPage()
    await page.goto(url, {'waitUntil': 'networkidle0'})
    elements = await page.querySelectorAll('a.btn.bg-purple')
    for element in elements:
      href = await page.evaluate('(element) => element.getAttribute("href")', element)
      # an anchor without an href evaluates to None; skip it instead of
      # failing on the substring test
      if href and 'blob:https://esmatlas.com/' in href:
        content = await fetch_blob_content(page, href)
        return href, content
    return "No PDB file link found.", None
  finally:
    await browser.close()

def esmfold_api(sequence):
  """Request a fold for ``sequence`` from the ESM Atlas result page.

  Args:
    sequence: sequence string inserted verbatim into the query URL.

  Returns:
    The PDB text (UTF-8 decoded) on success, otherwise the string
    "Failed to retrieve PDB content.".
  """
  url = f'https://esmatlas.com/resources/fold/result?fasta_header=%3Eunnamed&sequence={sequence}'
  loop = asyncio.get_event_loop()
  _link, content = loop.run_until_complete(extract_pdb_file_download_link_and_content(url))
  # empty or missing content is reported through the sentinel string
  if not content:
      return "Failed to retrieve PDB content."
  return content.decode('utf-8')

import jax
import jax.numpy as jnp
from colabdesign.af.alphafold.common import residue_constants

In [None]:
#@title # Unconditional Generation (Custom)
#@markdown For a given length, generate/hallucinate a protein sequence that AlphaFold thinks folds into a well-structured protein (high plddt, low pae, many contacts).
LENGTH = 100 #@param {type:"integer"}
#@markdown With COPIES you can specify the number of identical chains to design, resulting in homo-oligomers. COPIES = 1 is the standard and results in a monomer.

COPIES = 1 #@param ["1", "2", "3", "4", "5", "6", "7", "8"] {type:"raw"}
# Read by add_rg_loss below: "manuscript" evaluates the rg loss on every 5th
# residue, "original" applies the length-dependent threshold with an ELU.
MODE = "manuscript"

#@markdown Select the losses you want to use. For unconditional generation as reported in the manuscript use all the losses. To increase the diversity of designs, remove the confidence losses and/or increase the weight of the helix loss.

use_rg_loss = True #@param {type:"boolean"}
#@markdown An rg loss that is too strong can lead to problems and clashes. Use 0.1 for backbones smaller than 600 AA and 0.01 for larger proteins (0.001 for 1000 AA).
rg_weight = 0.1 #@param {type:"raw"}
use_helix_loss = True #@param {type:"boolean"}
use_con_loss = True #@param {type:"boolean"}
use_confidence_loss = True #@param {type:"boolean"}
#@markdown How many hallucination iterations you want to perform. The standard in the manuscript is 100.

iters = 50 #@param ["100", "50", "30"] {type:"raw"}


#@markdown Select whether you want to use the "standard" ProteinMPNN weights or the soluble ones. The soluble ones usually result in higher in silico as well as experimental success, but will increase the negative net charge of the protein, which could potentially interfere with certain protein design problems. The manuscript settings are soluble MPNN.

use_solubleMPNN = True #@param {type:"boolean"}
#@markdown Select this to use an experimental ProteinMPNN loss, also backpropagating through ProteinMPNN. This was not used in the manuscript.

use_mpnn_loss = False #@param {type:"boolean"}
#@markdown

def add_rg_loss(self, weight=0.1):
  '''register a radius-of-gyration loss named "rg" on the model

  self   = design model; the callback is appended to its model loss callbacks
  weight = value stored in self.opt["weights"]["rg"]

  the module-level MODE switch selects the variant:
    "manuscript" - rg computed on every 5th CA atom, no threshold
    "original"   - rg reduced by a length-dependent threshold, then ELU
  '''
  def loss_fn(inputs, outputs):
    positions = outputs["structure_module"]["final_atom_positions"]
    ca = positions[:, residue_constants.atom_order["CA"]]
    # for the binder protocol only the trailing binder residues are scored
    if self.protocol == "binder":
      ca = ca[-self._binder_len:]
    # scaled version of the rg loss, only looking at every 5th residue
    if MODE == "manuscript":
      ca = ca[::5]
    centered = ca - ca.mean(0)
    rg = jnp.sqrt(jnp.square(centered).sum(-1).mean() + 1e-8)

    if MODE == "original":
      threshold = 2.38 * ca.shape[0] ** 0.365
      rg = jax.nn.elu(rg - threshold)
    return {"rg": rg}

  self._callbacks["model"]["loss"].append(loss_fn)
  self.opt["weights"]["rg"] = weight

def add_mpnn_loss(self, mpnn=0.1, mpnn_seq=0.0):
  '''
  add mpnn loss
  mpnn = maximize confidence of proteinmpnn
  mpnn_seq = push designed sequence to match proteinmpnn logits

  A ProteinMPNN model is created (soluble or original weights, chosen by the
  module-level use_solubleMPNN flag) and stored on self._mpnn. The callback
  scores the predicted backbone with it and returns the two loss terms
  "mpnn" and "mpnn_seq"; their weights are written to self.opt["weights"].
  '''

  self._mpnn = mk_mpnn_model(weights = "soluble" if use_solubleMPNN else "original")
  def loss_fn(inputs, outputs, aux, key):

    # get structure (backbone atoms N, CA, C, O of the prediction)
    atom_idx = tuple(residue_constants.atom_order[k] for k in ["N","CA","C","O"])
    I = {"S":           inputs["aatype"],
         "residue_idx": inputs["residue_index"],
         "chain_idx":   inputs["asym_id"],
         "X":           outputs["structure_module"]["final_atom_positions"][:,atom_idx],
         "mask":        outputs["structure_module"]["final_atom_mask"][:,1],
         "lengths":     self._lengths,
         "key":         key}

    if "offset" in inputs:
      I["offset"] = inputs["offset"]

    # set autoregressive mask
    L = sum(self._lengths)
    if self.protocol == "binder":
      # everything except the diagonal is visible, but the trailing
      # self._len x self._len block (the designed positions) is hidden
      I["ar_mask"] = 1 - np.eye(L)
      I["ar_mask"][-self._len:,-self._len:] = 0
    else:
      I["ar_mask"] = np.zeros((L,L))

    # get logits (first 20 columns only), restricted to the designed positions
    logits = self._mpnn._score(**I)["logits"][:,:20]
    if self.protocol == "binder":
      logits = logits[-self._len:]
    else:
      logits = logits[:self._len]
    aux["mpnn_logits"] = logits

    # compute loss
    # (the original also computed softmax(logits) into an unused local)
    log_q = jax.nn.log_softmax(logits)
    p = inputs["seq"]["hard"]
    losses = {}
    # negative mean of the per-position maximum log-probability
    losses["mpnn"] = -log_q.max(-1).mean()
    # cross-entropy of the designed sequence against the MPNN distribution;
    # stop_gradient keeps this term from backpropagating through MPNN
    losses["mpnn_seq"] = -(p * jax.lax.stop_gradient(log_q)).sum(-1).mean()
    return losses

  self._callbacks["model"]["loss"].append(loss_fn)
  self.opt["weights"]["mpnn"] = mpnn
  self.opt["weights"]["mpnn_seq"] = mpnn_seq

# Free memory from earlier models, then build the hallucination model for
# the length and copy number selected in the form above.
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination")
af_model.prep_inputs(length=LENGTH, copies=COPIES)

# add extra losses
# (only the optional ProteinMPNN loss is registered here; the rg loss and
# the remaining weights are applied in the design cell)

if use_mpnn_loss: add_mpnn_loss(af_model)

print("length",af_model._lengths)
print("weights",af_model.opt["weights"])

In [None]:
#This cell runs the design loop. Run this in a for loop for design of multiple proteins

af_model.restart()
af_model.set_seq(mode=["gumbel","soft"])
if use_rg_loss:   add_rg_loss(af_model,rg_weight)
# Set every selectable weight explicitly, using 0.0 for unchecked losses.
# Previously a weight was only written when its box was ticked, so an
# unticked loss silently kept whatever value restart() left in place
# (NOTE(review): the hallucination protocol is expected to enable "con" by
# default, which made `use_con_loss = False` a no-op - confirm).
af_model.set_weights(helix=-0.2 if use_helix_loss else 0.0,
                     con=1.0 if use_con_loss else 0.0,
                     plddt=0.5 if use_confidence_loss else 0.0,
                     pae=0.5 if use_confidence_loss else 0.0)
print("weights",af_model.opt["weights"])
# all but the last 10 iterations first, then 10 more with save_best=True
af_model.design_logits(iters-10)
af_model.design_logits(10, save_best=True)

In [None]:
# Save the current design as "<protocol>.pdb" (the file name the later
# designability and ProteinMPNN cells read back) and plot it.
af_model.save_pdb(f"{af_model.protocol}.pdb")
af_model.plot_pdb()

In [None]:
# Display the model's animation output as HTML in the notebook.
HTML(af_model.animate())

In [None]:
# Show the sequence(s) reported by the design model.
af_model.get_seqs()

In [None]:
import pandas as pd
#@title # Designability test
#@markdown Test the designability of the backbone, taking in the backbone, generating sequences with solubleMPNN and predicting the sequence with AF2 in single sequence mode.
#@markdown Use Initial Guess (IG) and All Atom Initialisation (AA) for larger proteins

AA = False #@param {type:"boolean"}
IG = False #@param {type:"boolean"}
#@markdown NOTE: we remove cysteines from all designed proteins. Additionally, for large proteins we also exclude methionines to reduce the number of internal start codons.

def designability_test(af_model_test, mpnn_model_test,
                       num_seqs=8, sampling_temp=0.1, num_recycles=3,
                       model_num=4, best_metric="rmsd",
                       in_pdb="init.pdb", out_pdb="final.pdb",
                       verbose=False):
    """Sample sequences for a backbone and re-predict each of them.

    The backbone in ``in_pdb`` is loaded into ``af_model_test``, sequences are
    sampled with ``mpnn_model_test`` and every sequence is predicted with the
    AlphaFold model ``model_<model_num>_ptm``. The best result (according to
    ``best_metric``) is written to ``out_pdb``.

    Args:
        af_model_test: design model used for the predictions.
        mpnn_model_test: ProteinMPNN model used for sampling.
        num_seqs: number of sequences to evaluate; any positive value works,
            it no longer has to be a multiple of 8.
        sampling_temp: ProteinMPNN sampling temperature.
        num_recycles: recycles per prediction.
        model_num: index of the ``model_N_ptm`` parameter set to use.
        best_metric: metric name stored in ``af_model_test._args``.
        in_pdb: input backbone file.
        out_pdb: file the best prediction is saved to.
        verbose: forwarded to ``_save_results``.

    Returns:
        pandas.DataFrame with one row per sequence and the columns
        ``mpnn, plddt, ptm, pae, rmsd, dgram_cce, seq``.
    """
    alphafold_model = f"model_{model_num}_ptm"

    af_model_test.prep_inputs(in_pdb)
    # rm_aa="C,M": both cysteine and methionine are excluded
    af_model_test.restart(rm_aa="C,M")
    af_model_test._args["best_metric"] = best_metric
    mpnn_model_test.get_af_inputs(af_model_test)

    # Sample in batches of at most 8 and round the batch count up, so at
    # least num_seqs sequences exist. The previous `num_seqs // 8` produced
    # too few samples (and an IndexError below) for values such as 4 or 12.
    batch = max(1, min(8, num_seqs))
    num_batches = -(-num_seqs // batch)  # ceiling division
    out = mpnn_model_test.sample(num=num_batches, batch=batch,
                                 temperature=sampling_temp)

    af_terms = ["plddt", "ptm", "pae", "rmsd", "dgram_cce"]
    for k in af_terms: out[k] = []

    for n in range(num_seqs):
        seq = out["seq"][n]
        af_model_test.predict(seq=seq,
                              num_recycles=num_recycles,
                              num_models=1,
                              verbose=False,
                              models=alphafold_model)

        for k in af_terms: out[k].append(af_model_test.aux["log"][k])
        # rescale the logged pae by 31 (presumably from a normalised value
        # to Angstrom - confirm against the model's logging)
        out["pae"][-1] = out["pae"][-1] * 31
        af_model_test._save_results(save_best=True, verbose=verbose)
        af_model_test._k += 1

    af_model_test.save_pdb(out_pdb)
    # the MPNN "score" column is reported under the name "mpnn"
    labels = ["score"] + af_terms + ["seq"]
    data = [[out[k][n] for k in labels] for n in range(num_seqs)]
    labels[0] = "mpnn"
    df = pd.DataFrame(data, columns=labels)
    return df

# Fixed-backbone model for re-prediction; the IG / AA form switches above
# control use_initial_guess / use_initial_atom_pos.
af_model_test = mk_afdesign_model(protocol="fixbb",best_metric="rmsd",use_initial_guess=IG,use_initial_atom_pos=AA,use_templates=False)
mpnn_model_test = mk_mpnn_model(weights="soluble")


# NOTE(review): these two names are assigned but not read in this cell.
lowest_rmsd = float('inf')
lowest_rmsd_data = None

# Input is the file written by the save/plot cell ("<protocol>.pdb").
in_pdb = f"{af_model.protocol}.pdb"
out_pdb = f"{af_model.protocol}_out.pdb"


# `out` is the per-sequence DataFrame returned by designability_test.
out = designability_test(af_model_test, mpnn_model_test,
                    num_seqs=8, sampling_temp=0.1, num_recycles=3,
                    model_num=4, best_metric="rmsd",
                    in_pdb=in_pdb, out_pdb=out_pdb,
                    verbose=True)



In [None]:
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
import jax
import jax.numpy as jnp
#@title # OPTIONAL Unconditional Generation (Manuscript Code)

#@markdown This code generates a sample of 10 unconditional proteins for each length between 100 and 800 AA exactly as in the manuscript. For larger proteins CUDA unified memory is needed. This can be achieved by running the code locally on a CUDA-capable GPU with sufficient memory (e.g. an A100 80GB) with the environment variables XLA_PYTHON_CLIENT_MEM_FRACTION=100.0 TF_FORCE_UNIFIED_MEMORY=1
def rg_loss(inputs, outputs):
  """Radius-of-gyration loss on every 5th CA atom.

  The rg of the subsampled CA coordinates is reduced by the threshold
  ``2.38 * n ** 0.365`` (n = number of subsampled atoms) and passed through
  an ELU. Returns ``{"rg": value}``; ``inputs`` is not used.
  """
  ca_index = residue_constants.atom_order["CA"]
  # stride 5 along the first axis, CA atom along the second
  ca = outputs["structure_module"]["final_atom_positions"][::5, ca_index]
  sq_dist = jnp.square(ca - ca.mean(0)).sum(-1)
  rg = jnp.sqrt(sq_dist.mean() + 1e-8)
  rg_th = 2.38 * ca.shape[0] ** 0.365
  return {"rg": jax.nn.elu(rg - rg_th)}




# Ten designs for each length; one model is built per length and restarted
# for every design.
for length in [100,200,300,400,500,600,700,800]:
  model = mk_afdesign_model(protocol="hallucination",loss_callback=rg_loss)
  model.prep_inputs(length=length)
  print("weights",model.opt["weights"])
  print('Starting up and compiling JAX model....')

  for i in range(10):
    model.restart(mode=["gumbel", "soft"],rm_aa="C")
    # weaker rg weight for backbones longer than 600 residues
    model.opt["weights"]["rg"] = 0.01 if length > 600 else 0.1
    model.opt["weights"]['plddt'] = 1.0
    model.opt["weights"]['pae'] = 1.0
    model.opt["weights"]['helix'] = -0.1
    print("weights", model.opt["weights"])
    model.design_logits(100)

    #change the output path for local execution
    # The length is part of the file name: with the former "Hallo_{i}.pdb"
    # every length overwrote the ten designs of the previous one.
    model.save_pdb(f"Hallo_{length}_{i}.pdb")

In [None]:
#@markdown #Redesign with ProteinMPNN for ESMFold prediction
#@markdown The standard manuscript settings were 8 sequences, 0.1 sampling temperature and the removal of cysteines
import pickle
import pandas as pd  # used below; imported here so the cell does not rely on the designability cell
num_seqs = 8 #@param ["8", "16", "32", "64"] {type:"raw"}
mpnn_sampling_temp = 0.1 #@param ["0.0001", "0.1", "0.15", "0.2", "0.25", "0.3", "0.5", "1.0"] {type:"raw"}
rm_aa = "C" #@param {type:"string"}
use_solubleMPNN = False #@param {type:"boolean"}
#@markdown - `mpnn_sampling_temp` - control diversity of sampled sequences. (higher = more diverse).
#@markdown - `rm_aa='C'` - do not use [C]ysteines.
#@markdown - `use_solubleMPNN` - use weights trained only on soluble proteins.
#@markdown

from colabdesign.shared.protein import alphabet_list as chain_list
# The weight set is chosen when the model is created (as every other
# mk_mpnn_model call in this notebook does). It was previously passed to
# prep_inputs, where it did not select the weights.
mpnn_model = mk_mpnn_model(weights="soluble" if use_solubleMPNN else "original")
# `homooligomer` was misspelled as `homooligmer`, so the flag never reached
# the intended parameter.
mpnn_model.prep_inputs(pdb_filename=f"{af_model.protocol}.pdb",
                       chain=",".join(chain_list[:COPIES]),
                       homooligomer=COPIES>1,
                       rm_aa=rm_aa)
# sampled in batches of 8 (all num_seqs options are multiples of 8)
out = mpnn_model.sample(num=num_seqs//8,
                        batch=8,
                        temperature=mpnn_sampling_temp)
# print only the first chain of each sample ("/" separates chains)
for seq,score in zip(out["seq"],out["score"]):
  print(score,seq.split("/")[0])
df = pd.DataFrame(out["seq"])

# Define the output path for saving the sequences as a .pkl file
output_pkl_file = "redesigned_sequences.pkl"

# Save the DataFrame to a .pkl file (read back by the ESMFold cell)
with open(output_pkl_file, 'wb') as f:
    pickle.dump(df, f)

In [None]:
#@markdown #Run ESMFold to test designability
#@markdown This cells runs ESMFold from huggingface and automatically calculates the RMSD to the designed backbone
#@markdown NOTE: GPU memory can be a big problem here. If you get memory errors, please restart the runtime and run this cell again; it should be self-contained. Additionally, after the ESMFold prediction has finished, rerun the setup cell.

import os
import pandas as pd
from Bio.PDB import PDBParser, Superimposer
import pickle
import torch
import numpy as np
from transformers import AutoTokenizer, EsmForProteinFolding
from transformers.models.esm.openfold_utils.protein import to_pdb, Protein as OFProtein
from transformers.models.esm.openfold_utils.feats import atom14_to_atom37
# Sequences written by the ProteinMPNN redesign cell.
output_pkl_file = "redesigned_sequences.pkl"

with open(output_pkl_file, 'rb') as f:
    seq = pickle.load(f)
# first column of every row holds the sequence string
seq_list = [row[0] for row in np.asarray(seq)]

# reference backbone the predictions are compared against
pdb_file = "hallucination.pdb"
print(seq_list)


device = 'cuda:0'
tokenizer = AutoTokenizer.from_pretrained("facebook/esmfold_v1")
model = EsmForProteinFolding.from_pretrained("facebook/esmfold_v1", low_cpu_mem_usage=True)

# move to the GPU, run the language-model stem in half precision and chunk
# the trunk computation
model = model.cuda(device)
model.esm = model.esm.half()
model.trunk.set_chunk_size(64)
torch.backends.cuda.matmul.allow_tf32 = True

def convert_outputs_to_pdb(outputs):
    """Turn a batch of ESMFold outputs into PDB-format strings.

    Args:
        outputs: mapping of output tensors; the keys read here are
            "positions", "atom37_atom_exists", "aatype", "residue_index",
            "plddt" and, optionally, "chain_index".

    Returns:
        list with one PDB string per batch element.
    """
    # atom37 coordinates are computed from the last "positions" entry while
    # the outputs are still tensors
    atom37_positions = atom14_to_atom37(outputs["positions"][-1], outputs)
    np_out = {name: tensor.to("cpu").numpy() for name, tensor in outputs.items()}
    atom37_positions = atom37_positions.cpu().numpy()
    atom37_mask = np_out["atom37_atom_exists"]

    pdb_strings = []
    for idx, aatype in enumerate(np_out["aatype"]):
        if "chain_index" in np_out:
            chain_index = np_out["chain_index"][idx]
        else:
            chain_index = None
        protein = OFProtein(
            aatype=aatype,
            atom_positions=atom37_positions[idx],
            atom_mask=atom37_mask[idx],
            # residue numbering is shifted by one
            residue_index=np_out["residue_index"][idx] + 1,
            b_factors=np_out["plddt"][idx],
            chain_index=chain_index,
        )
        pdb_strings.append(to_pdb(protein))
    return pdb_strings

def calculate_ca_rmsd(pdb_file1, pdb_file2):
    """Superimpose the CA atoms of two PDB files and return the RMSD.

    Args:
        pdb_file1: structure whose CA atoms are passed first to the
            superimposer.
        pdb_file2: structure whose CA atoms are passed second; the fitted
            transform is applied to its atoms.

    Returns:
        The ``rms`` value reported by the superimposer.
    """
    pdb_parser = PDBParser(QUIET=True)
    first = pdb_parser.get_structure("Protein1", pdb_file1)
    second = pdb_parser.get_structure("Protein2", pdb_file2)

    def alpha_carbons(structure):
        # every atom named "CA", in file order
        return [atom for atom in structure.get_atoms() if atom.get_name() == "CA"]

    aligner = Superimposer()
    aligner.set_atoms(alpha_carbons(first), alpha_carbons(second))
    aligner.apply(second.get_atoms())
    return aligner.rms

def process_sequences(seq_list, pdb_file):
    """Fold every sequence with ESMFold and keep the one closest to ``pdb_file``.

    For each sequence the prediction is written to ``./output/TMP.pdb`` and
    its CA RMSD to ``pdb_file`` is computed. The prediction with the lowest
    RMSD is saved as ``./output/best_structure.pdb`` and its metrics are
    pickled to ``./output/best_structure_data.pkl``.

    Args:
        seq_list: iterable of sequence strings.
        pdb_file: path of the reference structure.

    Returns:
        ``(lowest_rmsd, best_pdb_file, metrics)`` where ``metrics`` is a dict
        with the keys "plddt", "pae" and "rmsd"; ``(None, None, None)`` when
        no sequence produced an RMSD below infinity (e.g. empty input).
    """
    lowest_rmsd = float('inf')
    best_metrics = None
    best_pdb_data = None
    out_ss_path = "./output"

    # tolerate an existing directory without a separate existence check
    os.makedirs(out_ss_path, exist_ok=True)
    tmp_pdb_file = os.path.join(out_ss_path, "TMP.pdb")

    for test_protein in seq_list:
        tokenized_input = tokenizer([test_protein], return_tensors="pt", add_special_tokens=False)['input_ids']
        tokenized_input = tokenized_input.cuda(device)

        with torch.no_grad():
            output = model(tokenized_input)

        data = {
            "plddt": torch.mean(output['plddt']).item(),
            'pae': torch.mean(output['predicted_aligned_error']).item(),
        }

        pdb_data = convert_outputs_to_pdb(output)
        with open(tmp_pdb_file, 'w') as file:
            file.writelines(pdb_data)

        data['rmsd'] = calculate_ca_rmsd(tmp_pdb_file, pdb_file)
        print(f'Sequence: {test_protein}, plddt: {data["plddt"]}, PAE: {data["pae"]}, RMSD: {data["rmsd"]}')

        if data['rmsd'] < lowest_rmsd:
            lowest_rmsd = data['rmsd']
            best_metrics = data
            # Keep the already converted PDB text instead of the raw model
            # output: this avoids holding the output tensors of the best
            # candidate for the rest of the run and converting them twice.
            best_pdb_data = pdb_data

    if best_metrics is None:
        return None, None, None

    print(f'Lowest RMSD: {lowest_rmsd}')
    best_pdb_file = os.path.join(out_ss_path, "best_structure.pdb")
    with open(best_pdb_file, 'w') as file:
        file.writelines(best_pdb_data)

    # copy so the pickled/returned dict is independent of the loop state
    data_out = dict(best_metrics)
    with open(os.path.join(out_ss_path, "best_structure_data.pkl"), 'wb') as f:
        pickle.dump(data_out, f)

    return lowest_rmsd, best_pdb_file, data_out



# Fold all redesigned sequences and report the closest match; the function
# returns (None, None, None) when no candidate was accepted.
lowest_rmsd, best_pdb_file, best_data = process_sequences(seq_list, pdb_file)
if lowest_rmsd is not None:
    print(f"Lowest RMSD: {lowest_rmsd}, Best PDB file: {best_pdb_file}")
else:
    print("No valid result found.")


In [None]:
#@title # Heterodimer Design Prep
#@markdown Design a set of heterodimeric proteins with two chains making a complex. The settings are exactly the ones used in the manuscript to design the heterodimer binders.
LENGTH1 = 100 #@param {type:"integer"}
LENGTH2 = 100 #@param {type:"integer"}

#@markdown ProteinMPNN Settings
use_solubleMPNN = True #@param {type:"boolean"}
#@markdown


from colabdesign.af.alphafold.common import residue_constants
import jax
import jax.numpy as jnp

def hd_loss(inputs, outputs):
  """Sum of the thresholded radius-of-gyration losses of both protomers.

  The first LENGTH1 residues form protomer 1 and the remaining residues
  protomer 2. For each, the rg of the CA atoms is reduced by the threshold
  ``2.38 * n ** 0.365`` (n = residues in that protomer) and passed through an
  ELU. Returns ``{"hd": loss1 + loss2}``; ``inputs`` is not used.
  """
  positions = outputs["structure_module"]["final_atom_positions"]
  ca = positions[:, residue_constants.atom_order["CA"]]

  def excess_rg(chain_ca):
    # radius of gyration above the length-dependent threshold
    center = chain_ca.mean(0)
    rg = jnp.sqrt(jnp.square(chain_ca - center).sum(-1).mean() + 1e-8)
    rg_th = 2.38 * chain_ca.shape[0] ** 0.365
    return jax.nn.elu(rg - rg_th)

  rg1 = excess_rg(ca[:LENGTH1])
  # The second protomer starts where the first one ends. The original
  # sliced from LENGTH2, which is only correct when both lengths are equal.
  rg2 = excess_rg(ca[LENGTH1:])

  return {"hd":rg1+rg2}

# Both protomers are hallucinated as one sequence of combined length.
total_length = LENGTH1 + LENGTH2
clear_mem()
af_model = mk_afdesign_model(protocol="hallucination", loss_callback=hd_loss)
af_model.prep_inputs(length=total_length)
# Shift the residue index of the second protomer by 50 (presumably so the
# model treats the two segments as separate chains - confirm).
af_model._inputs['residue_index'][LENGTH1:] = np.arange(LENGTH2) + 50 + LENGTH1
# add extra losses
af_model.restart(mode=["gumbel", "soft"])
af_model.opt["weights"]["hd"] = 0.1
af_model.opt["weights"]['plddt'] = 1.0
af_model.opt["weights"]['pae'] = 1.0
af_model.opt["weights"]['helix'] = -0.5
print("weights", af_model.opt["weights"])
print('Starting up and compiling JAX model....')


In [None]:
#@title # Run Design
# 100 design iterations, then save under the name the filter cell reads.
af_model.design_logits(100)
af_model.save_pdb("Heterodimer.pdb")

In [None]:
# Save again (same file name as the design cell) and plot the structure.
af_model.save_pdb("Heterodimer.pdb")
af_model.plot_pdb()

In [None]:
#@title # Design Sequence using Homooligomer Filter
#@markdown We first test if the two protomers are predicted to fold into a high-confidence protein on their own, removing proteins that are not likely to be expressed on their own. Then we predict the heterodimer using the AF multimer model. Generally the AF multimer model has a hard time predicting de novo designed proteins. This is why we use templates and remove any interchain information. Finally we predict each individual protomer with a copy of itself, testing for homooligomerisation.


file_path ="Heterodimer.pdb"

folder_path = "/content/"

######## make A - B chain file

from Bio.PDB import PDBParser, PDBIO, Chain

# Set the input and output PDB file names
input_pdb_file = file_path
os.makedirs(os.path.join(folder_path, 'AB'), exist_ok=True)
output_pdb_file = os.path.join(folder_path, 'AB',"Heterodimer.pdb")

# Create a PDB parser and read the input PDB file
parser = PDBParser()
structure = parser.get_structure("input_structure", input_pdb_file)

# Find the initial chain id (first chain of the first model)
initial_chain_id = None
for chain in structure[0]:
    initial_chain_id = chain.get_id()
    break

# Create new chains A and B
chain_A = Chain.Chain("A")
chain_B = Chain.Chain("B")

# Residue number ranges of the two protomers, derived from the design
# settings instead of the former hard-coded 1-100 / 151-450 (which only
# matched LENGTH1 = 100). The second protomer was designed with its residue
# index shifted by 50, so it starts at LENGTH1 + 50 + 1.
chain_B_first = LENGTH1 + 50 + 1
chain_B_last = LENGTH1 + 50 + LENGTH2

# Iterate over the residues in the original chain
for residue in structure[0][initial_chain_id]:
    res_id = residue.get_id()[1]

    # First protomer -> chain A
    if 1 <= res_id <= LENGTH1:
        chain_A.add(residue.copy())

    # Second protomer -> chain B
    elif chain_B_first <= res_id <= chain_B_last:
        chain_B.add(residue.copy())

# Remove the existing chain
for model in structure:
    model.detach_child(initial_chain_id)

# Add the new chains to the model
structure[0].add(chain_A)
structure[0].add(chain_B)

# Save the modified structure to a new PDB file
io = PDBIO()
io.set_structure(structure)
io.save(output_pdb_file)


clear_mem()
he_model = mk_afdesign_model(protocol="fixbb", use_templates=True, use_multimer=True)

# One homodimer model (copies=2) and one monomer model per distinct protomer
# length. Previously a single homodimer model was prepared for LENGTH1 and a
# single monomer model for LENGTH2, yet both were used for both protomers,
# which only fits when LENGTH1 == LENGTH2. With equal lengths exactly one
# model of each kind is built, as before.
ho_models, s_models = {}, {}
for protomer_len in (LENGTH1, LENGTH2):
    if protomer_len in s_models:
        continue
    ho_model = mk_afdesign_model(protocol="hallucination")
    ho_model.prep_inputs(length=protomer_len, copies=2)
    ho_model.set_weights(i_pae=1.0)
    ho_models[protomer_len] = ho_model

    s_model = mk_afdesign_model(protocol="hallucination")
    s_model.prep_inputs(length=protomer_len)
    s_models[protomer_len] = s_model

mpnn_model = mk_mpnn_model(weights="soluble")


mpnn_model.prep_inputs(pdb_filename=output_pdb_file, chain='A,B',rm_aa="C")
samples = mpnn_model.sample_parallel(8)

he_model.prep_inputs(pdb_filename=output_pdb_file, chain='A,B',rm_template_ic=True)
# same 50-residue index shift for the second chain as used during design
he_model._inputs['residue_index'][LENGTH1:] = np.arange(LENGTH2) + 50 + LENGTH1

k = 0
for seq in samples['seq']:
    # the sampled string holds both chains with one separator character
    # between them, hence the +1 when slicing out the second protomer
    seq1 = seq[:LENGTH1]
    seq2 = seq[LENGTH1+1:]

    print('Predicting Protomer 1...')
    s_models[LENGTH1].predict(seq=seq1, num_recycles=3)
    plddt1 = s_models[LENGTH1].aux['losses']['plddt']
    print('Predicting Protomer 2...')
    s_models[LENGTH2].predict(seq=seq2, num_recycles=3)
    plddt2 = s_models[LENGTH2].aux['losses']['plddt']
    k = k + 1
    # thresholds are applied to loss values (lower passes)
    if plddt1 < 0.20 and plddt2 < 0.20:
        print('Passed Protomer Check! Predicting Heterodimer...')
        he_model.predict(seq=''.join([seq1, seq2]), num_recycles=3)

        if he_model.aux['losses']['plddt'] < 0.15 and he_model.aux['losses']['rmsd'] < 2.0:
            print('Passed Heterodimer Check! Predicting Homodimer 1...')
            ho_models[LENGTH1].predict(seq=seq1,num_recycles=3)
            ipae1 = ho_models[LENGTH1].aux['losses']['i_pae']
            print('Predicting Homodimer 2...')
            ho_models[LENGTH2].predict(seq=seq2,num_recycles=3)
            ipae2 = ho_models[LENGTH2].aux['losses']['i_pae']
            # a high interface PAE loss for both self-pairings is required,
            # i.e. neither protomer is predicted to pair with itself
            if ipae1 > 0.8 and ipae2 > 0.8:
              print('Passed Homodimer check!')
              he_model.save_pdb(f'Heterodimer_seq_{k}.pdb')



In [None]:
def get_pdb(pdb_code=""):
  """Resolve ``pdb_code`` to a local PDB file name.

  - ``None`` or ``""``: prompt for an upload and store it as "tmp.pdb".
  - an existing file path: returned unchanged.
  - a 4-character code: downloaded from files.rcsb.org.
  - anything else: downloaded from alphafold.ebi.ac.uk (model_v3).

  The download result is not checked; the expected file name is returned
  either way.
  """
  if pdb_code is None or pdb_code == "":
    uploaded = files.upload()
    first_name = list(uploaded.keys())[0]
    with open("tmp.pdb","wb") as handle:
      handle.write(uploaded[first_name])
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
    return f"{pdb_code}.pdb"

  os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
  return f"AF-{pdb_code}-F1-model_v3.pdb"

def add_rg_loss(self, weight=0.1):
  '''register a radius-of-gyration loss ("rg") for the binder residues

  self   = design model; must provide _binder_len when the loss is evaluated
  weight = value stored in self.opt["weights"]["rg"]
  '''
  def loss_fn(inputs, outputs):
    positions = outputs["structure_module"]["final_atom_positions"]
    # CA atoms of the trailing binder residues only
    binder_ca = positions[:, residue_constants.atom_order["CA"]][-self._binder_len:]

    centered = binder_ca - binder_ca.mean(0)
    rg = jnp.sqrt(jnp.square(centered).sum(-1).mean() + 1e-8)
    # only the part above the length-dependent threshold is penalised (ELU)
    threshold = 2.38 * binder_ca.shape[0] ** 0.365
    return {"rg": jax.nn.elu(rg - threshold)}

  self._callbacks["model"]["loss"].append(loss_fn)
  self.opt["weights"]["rg"] = weight



#@title # Binder Design
#@markdown For a given target structure and binder length, generate/hallucinate a binder sequence that AlphaFold thinks folds into a well-structured protein bound to the target (high plddt, low pae, many contacts).
LENGTH = 100 #@param {type:"integer"}
binder_pdb = '5NGV' #@param {type:"string"}
binder_chain ='A' #@param {type:"string"}
hotspot ='' #@param {type:"string"}
# an empty hotspot field means "no hotspot restriction"
if hotspot == "": hotspot = None
#@markdown ProteinMPNN Settings
use_solubleMPNN = True #@param {type:"boolean"}
#@markdown

clear_mem()
af_model = mk_afdesign_model(protocol="binder")
add_rg_loss(af_model)
# NOTE: despite their names, binder_pdb / binder_chain are passed as the
# structure file and chain of prep_inputs; LENGTH is the binder length.
af_model.prep_inputs(pdb_filename=get_pdb(binder_pdb), chain=binder_chain,hotspot=hotspot, binder_len=LENGTH)


af_model.restart(mode=["gumbel", "soft"])

# loss weights are set after restart()
af_model.opt["weights"]["rg"] = 0.5

af_model.opt["weights"]['helix'] = -0.2
af_model.opt["weights"]['plddt'] = 0.1
af_model.opt["weights"]['pae'] = 0.1
af_model.opt["weights"]['i_pae'] = 0.1
af_model.opt["weights"]['i_con'] = 2.0

print("weights", af_model.opt["weights"])
print('Starting up and compiling JAX model....')


In [None]:
# 100 design iterations, then save under the name the filter cell reads.
af_model.design_logits(100)
af_model.save_pdb("Binder.pdb")

In [None]:
# Plot the designed complex.
af_model.plot_pdb()

In [None]:

#@title # Binder Sequence Design with AF Multimer filtering
#@markdown Use this to generate sequences for the binder candidate generated in the previous step

#@markdown First we use the AF2 PTM model to predict the binder without receptor, acting as a fast pre filter. Then we use the AF Multimer model to predict the Receptor Binder complex. Again we use a template for the binder to help AF Multimer predicting the de novo designed protein

# binder_model: complex prediction; hall_model: binder-only pre-filter
binder_model = mk_afdesign_model(protocol="binder",use_multimer=True,use_initial_guess=True)
hall_model = mk_afdesign_model(protocol="fixbb")


binder_model.set_weights(i_pae=1.0)


# chain A is passed as fix_pos, so sequences are sampled for chain B
mpnn_model = mk_mpnn_model(weights="soluble")
mpnn_model.prep_inputs(pdb_filename="Binder.pdb", chain='A,B', fix_pos='A',rm_aa="C")

samples = mpnn_model.sample_parallel(8,temperature=0.01)
hall_model.prep_inputs(pdb_filename="Binder.pdb", chain='B')
binder_model.prep_inputs(pdb_filename="Binder.pdb", chain='A', binder_chain='B',use_binder_template=True,rm_template_ic=True)
# k numbers the samples from 0 and is used in the output file name
k=0
for seq in samples['seq']:
    print("Predicting binder only")
    # the binder is the last LENGTH characters of the sampled string
    hall_model.predict(seq=seq[-LENGTH:], num_recycles=3)
    if hall_model.aux['losses']['rmsd'] < 2.0 :
        print("Passed! Predicting binder with receptor using AF Multimer")
        binder_model.predict(seq=seq[-LENGTH:], num_recycles=3)
        plddt1 = binder_model.aux['losses']['plddt']
        i_pae = binder_model.aux['losses']['i_pae']
        # thresholds are applied to loss values (lower passes)
        if plddt1 < 0.15 and i_pae < 0.4:
           print(f"Passed! Final I_PAE is {i_pae*31}")
           binder_model.save_pdb(f'Binder_seq_{k}.pdb')
           binder_model.plot_pdb()

    k = k + 1

In [None]:
#@title # Site scaffolding example
#@markdown This cell provides the code to perform the site scaffolding in bulk.
#@markdown Just go to the commented section with names, contigs and length to insert the desired PDB identifier, contigs and final size and start designing.
#@markdown Num_designs controls how many backbones are designed per PDB file

num_designs = 1 #@param {type:"integer"}


def get_pdb(pdb_code=""):
  """Return a local file name for ``pdb_code``, fetching the file if needed.

  An empty or ``None`` code triggers an upload saved as "tmp.pdb"; an
  existing path is returned as is; a 4-character code is fetched from
  files.rcsb.org and any other code from alphafold.ebi.ac.uk (model_v3).
  Download failures are not detected.
  """
  nothing_given = pdb_code is None or pdb_code == ""
  if nothing_given:
    upload_dict = files.upload()
    uploaded_names = list(upload_dict.keys())
    pdb_bytes = upload_dict[uploaded_names[0]]
    with open("tmp.pdb","wb") as out:
      out.write(pdb_bytes)
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) != 4:
    # not a 4-character code: treated as an AlphaFold DB identifier
    os.system(f"wget -qnc https://alphafold.ebi.ac.uk/files/AF-{pdb_code}-F1-model_v3.pdb")
    return f"AF-{pdb_code}-F1-model_v3.pdb"

  os.system(f"wget -qnc https://files.rcsb.org/view/{pdb_code}.pdb")
  return f"{pdb_code}.pdb"



from colabdesign import mk_afdesign_model, clear_mem
import contextlib

from colabdesign.af.alphafold.common import residue_constants
import jax
import jax.numpy as jnp
import pickle
from colabdesign.mpnn import mk_mpnn_model

import re
import os


# The three lists below are iterated together (zip), so entry i of each
# list describes the same scaffolding problem.
#Add the names of the PDB files for the scaffolding problem here
names = [
    "1PRW"
]
print(len(names))
#Add the design contigs here
# Comma-separated segments: entries containing a chain letter (e.g. "A16-35")
# are handled differently from purely numeric ranges (e.g. "5-20") below.

inputs = [
    "5-20,A16-35,10-25,A52-71,5-20"
]
#Add the total length here. We only use the maximum length specified
lengths = [
    "60-105"
]


def rg_loss(inputs, outputs):
    """Radius-of-gyration penalty computed from predicted CA coordinates.

    Reads ``outputs["structure_module"]["final_atom_positions"]``, keeps every
    fifth entry along the first axis (presumably residues - TODO confirm the
    subsampling is intended) and the CA atom slot, and measures the
    root-mean-square distance of those points from their centroid. The value
    returned is ``elu(rg - 2.38 * n ** 0.365)`` where ``n`` is the number of
    sampled points.

    Args:
        inputs: Unused; kept so the signature stays ``(inputs, outputs)``.
        outputs: Model outputs containing the structure-module atom positions.

    Returns:
        dict with a single key ``"rg"`` holding the penalty.
    """
    ca_slot = residue_constants.atom_order["CA"]
    atom_xyz = outputs["structure_module"]["final_atom_positions"]
    ca_xyz = atom_xyz[::5, ca_slot]

    displacement = ca_xyz - ca_xyz.mean(0)
    mean_sq_dist = jnp.square(displacement).sum(-1).mean()
    # Small epsilon keeps the square root well-behaved at zero spread.
    radius = jnp.sqrt(mean_sq_dist + 1e-8)

    threshold = 2.38 * ca_xyz.shape[0] ** 0.365
    return {"rg": jax.nn.elu(radius - threshold)}


# Stage 1 of the site-scaffolding pipeline. For every (PDB, contig, length)
# triple:
#   1. build a "partial" design model constrained to the contig's PDB segments,
#   2. run `num_designs` logit optimisations and save each backbone,
#   3. sample sequences for every backbone with ProteinMPNN (constrained
#      positions are passed as fix_pos),
#   4. re-predict each sampled sequence with a "fixbb" model and keep those
#      with rmsd < 2.0 and plddt > 0.85.
for _name, _input, _length in zip(names, inputs, lengths):
    print(f"Starting on {_name}")
    # Called once per target (not only once before the loop), matching the
    # filtering loop further down; presumably frees memory held by the
    # previous target's models before new ones are built.
    clear_mem()

    _input = _input.replace(" ", "")
    model = mk_afdesign_model(protocol="partial")

    # Classify each contig segment: "l" = carries a chain letter (copied from
    # the input PDB), "w" = plain number/range (a linker of free residues).
    segments = _input.split(",")
    wire_loop_repr = ["l" if re.search("[A-Z]", x) else "w" for x in segments]

    # Maximum number of residues each segment can contribute.
    _lengths = []
    for _id, rep in zip(wire_loop_repr, segments):
        if _id == "l":
            rep = rep[1:]  # drop the leading chain letter
            if "-" in rep:
                _len = int(rep.split("-")[1]) - int(rep.split("-")[0]) + 1
            else:
                _len = 1  # single PDB residue
        elif "-" in rep:
            _len = int(rep.split("-")[1])  # only the upper bound is used
        else:
            # A linker given as a single number has that many residues
            # (previously counted as 1, which under-estimated overall_length).
            _len = int(rep)
        _lengths.append(_len)

    overall_length = sum(_lengths)
    print(overall_length)

    fixed_segments = [x for x, _id in zip(segments, wire_loop_repr) if _id == "l"]
    order = list(range(len(fixed_segments)))
    old_pos = ",".join(fixed_segments)

    # Linker lengths in contig order; a leading linker becomes the N-terminal
    # offset and a trailing linker is dropped, leaving only the linkers that
    # sit between two fixed segments.
    wires = [n for _id, n in zip(wire_loop_repr, _lengths) if _id == "w"]
    offset = wires[0] if wire_loop_repr[0] == "w" else 0
    if wire_loop_repr[0] == "w":
        wires = wires[1:]
    if wire_loop_repr[-1] == "w":
        wires = wires[:-1]

    chains = set(re.findall("[A-Z]", _input))
    if len(chains) != 1:
        raise ValueError(
            f"contig {_input!r} must reference exactly one chain, found {sorted(chains)}"
        )
    chain = chains.pop()

    # Only the upper bound of a "min-max" length is used, and it is never
    # allowed to be smaller than what the contig itself requires.
    if "-" in _length:
        _length = _length.split("-")[1]
    _length = max(int(_length), overall_length)

    debug = False
    if debug:
        print("chain  " + str(chain))
        print("old_pos  " + str(old_pos))
        print("wires  " + str(wires))
        print("offset  " + str(offset))
        print("_length  " + str(_length))
        print("order  " + str(order))
    print(_name)
    pdb_file = get_pdb(_name)
    model.prep_inputs(
        pdb_file,
        chain=chain,
        pos=old_pos,
        length=_length,
        fix_seq=True,
    )

    model.rewire(
        order=order,  # set order of segments
        loops=wires,  # change loop length inbetween segments
        offset=offset,
    )  # essentially loop length at the N term

    print("   Starting up and compiling JAX model....")

    design_dir = f"out_sc/{_name}_resesigned"
    os.makedirs(design_dir, exist_ok=True)

    for i in range(num_designs):
        print(f"      Iteration {i + 1} of {num_designs}")
        model.restart(mode=["gumbel", "soft"], rm_aa="C")
        # NOTE(review): no loss named "rg" is registered on this model here
        # (rg_loss is defined above but not passed to mk_afdesign_model), so
        # this weight presumably has no effect - confirm.
        model.opt["weights"]["rg"] = 0.1
        model.opt["weights"]["dgram_cce"] = 2.0
        model.opt["weights"]["plddt"] = 0.1
        model.opt["weights"]["pae"] = 0.1
        model.opt["weights"]["rmsd"] = 1.0
        model.opt["weights"]['sc_rmsd'] = 1.0
        #           model.opt["weights"]['fape'] = 1.0

        # 190 optimisation steps, then 10 more with save_best=True.
        model.design_logits(190)
        model.design_logits(10, save_best=True)
        model.save_pdb(f"{design_dir}/{_name}_redesigned_{i}.pdb")

    mpnn_model = mk_mpnn_model()

    # Constrained positions as a comma-separated string for ProteinMPNN.
    # NOTE(review): the +1 assumes model.opt["pos"] is 0-based while fix_pos
    # is 1-based (the original code marked this "Might be wrong") - confirm.
    posf = ",".join(str(k + 1) for k in model.opt["pos"])

    repredictionModel = mk_afdesign_model(
        protocol="fixbb", use_templates=False
    )
    # Create the per-target output directory up front so it exists even when
    # no sequence passes the filter (the later filtering loop lists it).
    passed_dir = f"out_sc_Redesigned/{_name}_resesigned"
    os.makedirs(passed_dir, exist_ok=True)

    for j in range(num_designs):
        print(f"      Reprediction Iteration {j + 1} of {num_designs}")
        backbone_pdb = f"{design_dir}/{_name}_redesigned_{j}.pdb"
        repredictionModel.prep_inputs(backbone_pdb)

        mpnn_model.prep_inputs(
            pdb_filename=backbone_pdb,
            chain="A",
            fix_pos=posf,
            rm_aa="C",
        )
        out = mpnn_model.sample(num=1, batch=8, temperature=0.1)

        for n, mpnn_seq in enumerate(out["seq"]):
            repredictionModel.predict(seq=mpnn_seq, num_recycles=3)
            rmsd = repredictionModel.aux["log"]["rmsd"]
            plddt = repredictionModel.aux["log"]["plddt"]
            if rmsd < 2.0 and plddt > 0.85:
                # The RMSD (x100, truncated) is encoded in the file name.
                filename = f'{passed_dir}/{_name}_redesigned-{j}_num-{n}_rmsd-{int(rmsd*100)}.pdb'
                repredictionModel.save_pdb(filename)

# Stage 2 of the site-scaffolding pipeline. For every target, rebuild the
# "partial" model for the contig, re-predict each PDB found in
# out_sc_Redesigned/<name>_resesigned/ with it, and copy the ones whose
# "rmsd" loss is below 1.0 into the out/ sub-directory together with a pickle
# of that value.
for _name, _input, _length in zip(names, inputs, lengths):
    print(f"Starting on {_name}")
    clear_mem()

    _input = _input.replace(" ", "")

    # test_model is only used to read the sequence back out of a saved PDB.
    test_model = mk_afdesign_model(protocol='fixbb')
    model = mk_afdesign_model(
        protocol="partial", use_templates=False
    )  # set True to constrain positions using template input

    # Classify each contig segment: "l" = carries a chain letter (copied from
    # the input PDB), "w" = plain number/range (a linker of free residues).
    segments = _input.split(",")
    wire_loop_repr = ["l" if re.search("[A-Z]", x) else "w" for x in segments]

    # Maximum number of residues each segment can contribute.
    _lengths = []
    for _id, rep in zip(wire_loop_repr, segments):
        if _id == "l":
            rep = rep[1:]  # drop the leading chain letter
            if "-" in rep:
                _len = int(rep.split("-")[1]) - int(rep.split("-")[0]) + 1
            else:
                _len = 1  # single PDB residue
        elif "-" in rep:
            _len = int(rep.split("-")[1])  # only the upper bound is used
        else:
            # A linker given as a single number has that many residues
            # (previously counted as 1, which under-estimated overall_length).
            _len = int(rep)
        _lengths.append(_len)

    overall_length = sum(_lengths)

    fixed_segments = [x for x, _id in zip(segments, wire_loop_repr) if _id == "l"]
    order = list(range(len(fixed_segments)))
    old_pos = ",".join(fixed_segments)

    # Linker lengths in contig order; a leading linker becomes the N-terminal
    # offset and a trailing linker is dropped, leaving only the linkers that
    # sit between two fixed segments.
    wires = [n for _id, n in zip(wire_loop_repr, _lengths) if _id == "w"]
    offset = wires[0] if wire_loop_repr[0] == "w" else 0
    if wire_loop_repr[0] == "w":
        wires = wires[1:]
    if wire_loop_repr[-1] == "w":
        wires = wires[:-1]

    chains = set(re.findall("[A-Z]", _input))
    if len(chains) != 1:
        raise ValueError(
            f"contig {_input!r} must reference exactly one chain, found {sorted(chains)}"
        )
    chain = chains.pop()

    # Only the upper bound of a "min-max" length is used, and it is never
    # allowed to be smaller than what the contig itself requires.
    if "-" in _length:
        _length = _length.split("-")[1]
    _length = max(int(_length), overall_length)

    print(_name)
    pdb_file = get_pdb(_name)

    model.prep_inputs(
        pdb_file,
        chain=chain,
        pos=old_pos,  # define positions to contrain
        length=_length,  # define if the desired length is different from input PDB
        fix_seq=True,
    )  # set True to constrain the sequence

    # set positions (if different from PDB)
    # reorder the segments,
    model.rewire(
        order=order,  # set order of segments
        loops=wires,  # change loop length inbetween segments
        offset=offset,
    )  # essentially loop length at the N term

    redesigned_dir = f'out_sc_Redesigned/{_name}_resesigned/'
    passed_dir = f'out_sc_Redesigned/{_name}_resesigned/out/'
    # makedirs also creates redesigned_dir when the previous stage saved
    # nothing for this target, so the listing below yields no candidates
    # instead of raising FileNotFoundError.
    os.makedirs(passed_dir, exist_ok=True)

    for ii in sorted(os.listdir(redesigned_dir)):
        # Only PDB files are candidates; this also skips the out/ directory.
        if not ii.endswith(".pdb"):
            continue
        test_model.prep_inputs(pdb_filename=f'{redesigned_dir}{ii}')
        seq = test_model._inputs['batch']["aatype"]
        model.predict(seq=seq, num_recycles=3)
        rmsd = model.aux["losses"]["rmsd"]
        if rmsd < 1.0:
            model.save_pdb(f'{passed_dir}{ii}')
            # ii[:-4] strips the ".pdb" suffix.
            with open(f'{passed_dir}{ii[:-4]}_data.pkl', 'wb') as f:
                pickle.dump(rmsd, f)

