<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/v1.1.1/af/examples/use_esm_1b_bias.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

# Gather inputs

In [None]:
import requests, re
from google.colab import files

def get_uniprot_seq(uid):
  """Return the sequence of the first FASTA record UniProt sends back for `uid`.

  Args:
    uid: text placed in the `query=` parameter of the UniProtKB stream
      endpoint (the notebook passes a UniProt accession).

  Returns:
    The sequence as a single string with the FASTA header and line breaks
    removed.

  Raises:
    requests.HTTPError: if UniProt answers with an error status.
    requests.Timeout: if the server does not respond within 60 seconds.
  """
  url = f'https://rest.uniprot.org/uniprotkb/stream?compressed=false&format=fasta&query={uid}'
  response = requests.get(url, timeout=60)
  # Fail loudly rather than parsing an error body as if it were FASTA.
  response.raise_for_status()
  # Records are separated at a newline followed by ">"; keep only the first one.
  first_record = re.split(r'\n(?=>)', response.text)[0]
  # Drop the header line and stitch the wrapped sequence lines back together.
  sequence = "".join(first_record.split("\n")[1:])
  return sequence

def get_pdb(pdb_code="",alphafold_model=False):
  """Return the local filename of a structure, uploading or downloading it first.

  Args:
    pdb_code: identifier used to build the download URL. `None` or `""`
      opens the Colab upload dialog instead of downloading.
    alphafold_model: when true, fetch `AF-{pdb_code}-F1-model_v3.pdb` from
      alphafold.ebi.ac.uk; otherwise fetch `{pdb_code}.pdb` from
      files.rcsb.org.

  Returns:
    "tmp.pdb" for an upload, otherwise the name of the downloaded file in
    the current working directory.

  Note:
    The download is delegated to `wget -qnc` and its exit status is not
    checked, so the returned filename is not guaranteed to exist.
  """
  # Imported here because this function is called before the notebook's
  # later `import os`; without it the download branches raise NameError.
  import os

  if pdb_code is None or pdb_code == "":
    upload_dict = files.upload()
    # Use the first uploaded file; its value is the raw file content.
    pdb_string = upload_dict[list(upload_dict.keys())[0]]
    with open("tmp.pdb","wb") as out:
      out.write(pdb_string)
    return "tmp.pdb"

  if alphafold_model:
    filename = f"AF-{pdb_code}-F1-model_v3.pdb"
    url = f"https://alphafold.ebi.ac.uk/files/{filename}"
  else:
    filename = f"{pdb_code}.pdb"
    url = f"https://files.rcsb.org/view/{filename}"

  # -q: quiet, -nc: do not download again if the file is already present.
  os.system(f"wget -qnc {url}")
  return filename

In [None]:
# UniProt accession that drives both the sequence lookup and the structure download
UNIPROT = "P0A6A8"
SEQUENCE = get_uniprot_seq(uid=UNIPROT)
# alphafold_model=True selects the AlphaFold DB file for this accession
PDB_FILENAME = get_pdb(pdb_code=UNIPROT, alphafold_model=True)

# ESM-1b
Use logits from ESM-1b as a prior for AfDesign.

In [None]:
!pip -q install fair-esm
import esm
import gc
import torch
import numpy as np

In [None]:
# load the pretrained ESM-1b network together with its token alphabet
model, alphabet = esm.pretrained.esm1b_t33_650M_UR50S()

# run model on GPU if available
device = torch.device('cuda:0') if torch.cuda.is_available() else torch.device('cpu')

# inference only: move the weights to the chosen device and switch to eval mode
model = model.to(device).eval()
# NOTE(review): turns off the model's token-dropout option; presumably needed
# for consistent logits when positions are masked — confirm against fair-esm
model.args.token_dropout = False

In [None]:
def get_bias_from_esm(seq, p=None):
  '''Compute per-position logits for `seq`, masking one residue at a time.

  Uses the notebook-level `alphabet`, `model` and `device` objects.

  seq: sequence string to score.
  p: number of masked copies evaluated per forward pass; defaults to
     len(seq), i.e. everything in a single batch (lower it if GPU memory
     is tight).

  Returns a numpy array of shape (len(seq), 20) whose columns follow the
  order "ARNDCQEGHILKMFPSTWYV".
  '''
  aa_order = "ARNDCQEGHILKMFPSTWYV"

  # map esm-alphabet to standard-alphabet: where each residue sits inside
  # the 4:24 slice of the ESM token list
  esm_index = {}
  for idx, tok in enumerate(alphabet.all_toks[4:24]):
    esm_index[tok] = idx
  column_order = np.array([esm_index[aa] for aa in aa_order])

  tokens = alphabet.get_batch_converter()([(None,seq)])[-1]
  ln = len(seq)
  batch_size = ln if p is None else p

  def residue_logits(batch):
    # keep sequence positions 1..ln and the 20 token columns 4:24
    return model(batch)["logits"][:,1:(ln+1),4:24]

  logits = np.zeros((ln,20))
  with torch.no_grad():
    for start in range(0,ln,batch_size):
      stop = min(start+batch_size,ln)
      positions = range(start,stop)
      # one copy of the tokens per position in this chunk, each with one mask
      masked = torch.tile(torch.clone(tokens),[stop-start,1])
      for row,pos in enumerate(positions):
        # +1 mirrors the offset of the 1:(ln+1) slice taken from the output
        masked[row,pos+1] = alphabet.mask_idx
      chunk_out = residue_logits(masked.to(device))
      # from each copy, keep only the logits at the position that was masked
      for row,pos in enumerate(positions):
        logits[pos] = chunk_out[row,pos].cpu().numpy()

  return logits[:,column_order]

In [None]:
# get bias
seq = SEQUENCE
# one row per residue of seq, one column per amino acid
bias = get_bias_from_esm(seq)
# written to disk so it can be reloaded later with np.loadtxt
np.savetxt("bias.txt",bias)

In [None]:
import matplotlib.pyplot as plt
# heatmap of the bias, transposed so sequence positions run along the x-axis;
# color limits are fixed at -10 and 10
plt.imshow(bias.T,cmap="bwr_r",vmin=-10,vmax=10)

In [None]:
# clear GPU memory
# drop the ESM model reference, run the garbage collector, then release
# the CUDA memory PyTorch is still caching
del model
gc.collect()
torch.cuda.empty_cache()

# AfDesign


In [None]:
#@title setup afdesign
%%time
import os
if not os.path.isdir("params"):
  # get code
  os.system("pip -q install git+https://github.com/sokrypton/ColabDesign.git@v1.1.1")
  # for debugging
  os.system("ln -s /usr/local/lib/python3.*/dist-packages/colabdesign colabdesign")
  # download params
  os.system("mkdir params")
  os.system("apt-get install aria2 -qq")
  os.system("aria2c -q -x 16 https://storage.googleapis.com/alphafold/alphafold_params_2022-12-06.tar")
  os.system("tar -xf alphafold_params_2022-12-06.tar -C params")

import warnings
warnings.simplefilter(action='ignore', category=FutureWarning)

import os
from colabdesign import mk_afdesign_model, clear_mem
from colabdesign.af.alphafold.common import residue_constants
from IPython.display import HTML
import numpy as np



In [None]:
# presumably frees memory held by earlier models — confirm against ColabDesign
clear_mem()
# build a design model using the "fixbb" protocol, without templates
model = mk_afdesign_model(protocol="fixbb",
                          use_templates=False) # set True to constrain structure
# load chain A of the structure file obtained in the "Gather inputs" section
model.prep_inputs(PDB_FILENAME, chain="A")
print("length",  model._len)

In [None]:
import matplotlib.pyplot as plt
# reload the bias from the text file written by np.savetxt
bias = np.loadtxt("bias.txt")
# same heatmap as before: positions along x, color limits fixed at -10 and 10
plt.imshow(bias.T,cmap="bwr_r",vmin=-10,vmax=10)

In [None]:
# presumably resets the model's design state before a new run — TODO confirm
model.restart()
# pass the ESM-derived array loaded from bias.txt as the sequence bias
model.set_seq(bias=bias)
# NOTE(review): the three numbers look like per-stage iteration counts for
# the three design stages — confirm against the ColabDesign API
model.design_3stage(50,50,10)

In [None]:
# wrap the output of model.animate() in an IPython HTML object for inline display
HTML(model.animate())

In [None]:
# show the value returned by get_seqs() (presumably the designed sequence(s))
model.get_seqs()

In [None]:
# write the structure to a file named after the model's protocol attribute
model.save_pdb(f"{model.protocol}.pdb")
# presumably renders the structure inline — confirm against ColabDesign
model.plot_pdb()