<a href="https://colab.research.google.com/github/sokrypton/ColabDesign/blob/main/rf/examples/diffusion_foldcond.ipynb" target="_parent"><img src="https://colab.research.google.com/assets/colab-badge.svg" alt="Open In Colab"/></a>

**RFdiffusion** - conditional fold generation
---

**<font color="red">NOTE</font>** This notebook is in development; we are still working on adding all the options from the [manuscript](https://www.biorxiv.org/content/10.1101/2022.12.09.519842v2).

**instructions**:
1. select the blueprint mode (`manual` or `automated`)
2. enter the info and hit the ▶️ button
 - **RFdiffusion** takes ~1 min to set up; the next time you run this cell it will only take seconds!

3. modify the blueprint
 - use the diagonal cells to define the SSEs (`H:helix E:sheet C:coil ?:undefined`)
 - use the off-diagonal cells to define interactions (`0:no_contact 1:contact ?:undefined`)
 - use the textbox in the last column to define the length of each SSE
 - define the buffer length (`buff_length`) between SSEs

In [None]:
#@title Generate blueprint for **RFdiffusion**

# prefix for the output files (passed to RFdiff_gui as `name`)
name = "test"
blueprint_mode = "manual" #@param ["manual", "automated"]
# NOTE(review): run_mode does not appear to be read anywhere below — confirm before removing
run_mode = "unconditional"

#@markdown ---
#@markdown **Manual** blueprint (define number of secondary structure `elements` (SSE))
elements = 5 #@param ["1", "2", "3", "4", "5", "6", "7", "8", "9", "10", "11", "12", "13", "14", "15", "16", "17", "18", "19", "20"] {type:"raw"}
#@markdown ---
#@markdown **Automated** blueprint (from input PDB)
pdb = "6MRR" #@param {type:"string"}
chain = "A" #@param {type:"string"}
trim_loops = True #@param {type:"boolean"}
# an empty chain field becomes None, which pdb_to_string treats as "all chains"
if chain == "": chain = None

import os, time, sys

######################################################################
# SETUP RFDIFFUSION
######################################################################
# Parameter files, in the order the background download job fetches them.
param_urls = {
  "schedules.zip": "https://files.ipd.uw.edu/krypton/schedules.zip",
  "Complex_Fold_base_ckpt.pt": "http://files.ipd.uw.edu/pub/RFdiffusion/60f09a193fb5e5ccdc4980417708dbab/Complex_Fold_base_ckpt.pt",
}
# The background download job records its pid here so the params step can wait for it.
param_pid_file = "params_download.pid"

if not os.path.isdir("RFdiffusion"):
  print("installing RFdiffusion...")
  # -y: os.system() gives apt-get no terminal to confirm on, so without it the
  # install aborts as soon as extra dependency packages are required
  os.system("apt-get install -y aria2")
  # send param download into background (files are fetched one after another)
  downloads = "; ".join(f"aria2c -q -x 16 {url}" for url in param_urls.values())
  os.system(f"({downloads}) & echo $! > {param_pid_file}")

  # install RFdiffusion
  os.system("git clone https://github.com/sokrypton/RFdiffusion.git")
  os.system("pip install jedi omegaconf hydra-core icecream pyrsistent pynvml decorator")
  os.system("pip install git+https://github.com/NVIDIA/dllogger#egg=dllogger")
  # 17Mar2024: adding --no-dependencies to avoid installing nvidia-cuda-* dependencies
  # 25Aug2025: updating dgl install to work with latest pytorch
  os.system("pip install --no-dependencies dgl -f https://data.dgl.ai/wheels/torch-2.4/cu124/repo.html")
  os.system("pip install --no-dependencies e3nn==0.5.5 opt_einsum_fx")
  os.system("cd RFdiffusion/env/SE3Transformer; pip install .")

  # extras
  os.system("pip -q install py3Dmol pydssp")
  os.system("wget -qnc https://raw.githubusercontent.com/sokrypton/ColabDesign/main/colabdesign/rf/blueprint.js")
  os.system("wget -qnc https://raw.githubusercontent.com/sokrypton/ColabDesign/main/colabdesign/rf/blueprint.css")

if not os.path.isdir("RFdiffusion/models"):
  print("downloading RFdiffusion params...")
  os.makedirs("RFdiffusion/models", exist_ok=True)
  # wait for the background download job (if one was started) to exit, so we
  # never touch a file it is still writing
  if os.path.isfile(param_pid_file):
    with open(param_pid_file) as handle:
      pid_text = handle.read().strip()
    if pid_text.isdigit():
      while True:
        try:
          os.kill(int(pid_text), 0)  # signal 0 only checks that the process exists
        except OSError:
          break
        time.sleep(5)
    os.remove(param_pid_file)
  # fetch or resume (-c) anything the background job did not fully deliver, e.g.
  # because RFdiffusion/ already existed so no job was started, or a download failed
  # (aria2c leaves a "<file>.aria2" control file next to unfinished downloads)
  for fn, url in param_urls.items():
    if not os.path.isfile(fn) or os.path.isfile(f"{fn}.aria2"):
      os.system(f"aria2c -q -x 16 -c {url}")
  models = ["Complex_Fold_base_ckpt.pt"]
  os.system(f"mv {' '.join(models)} RFdiffusion/models")
  # -o: overwrite existing files instead of prompting (there is no terminal to answer)
  os.system("unzip -o schedules.zip; rm schedules.zip")
  print("----------------------------------")

if 'RFdiffusion' not in sys.path:
  os.environ["DGLBACKEND"] = "pytorch"
  sys.path.append('RFdiffusion')
######################################################################

from IPython.display import display
import ipywidgets as widgets
import torch
import random, string, re
import numpy as np
import subprocess
import matplotlib.pyplot as plt
import py3Dmol
from google.colab import files, output

from string import ascii_uppercase, ascii_lowercase
# 52 single-character labels: "A".."Z" followed by "a".."z"
alphabet_list = [*ascii_uppercase, *ascii_lowercase]

def get_pdb(pdb_code=None):
  '''Return a local path to a structure file, fetching or uploading it if needed.

  - None or "": ask the user to upload a file; it is saved as tmp.pdb
  - an existing local path: returned unchanged
  - a 4-character code: downloaded from RCSB (https://files.rcsb.org/view/)
  - anything else: downloaded from the AlphaFold DB as AF-<code>-F1-model_v3.pdb

  Downloads use `wget -qnc`, so an already-present file is reused.
  '''
  if pdb_code is None or pdb_code == "":
    uploaded = files.upload()
    first_name = list(uploaded)[0]
    with open("tmp.pdb", "wb") as handle:
      handle.write(uploaded[first_name])
    return "tmp.pdb"

  if os.path.isfile(pdb_code):
    return pdb_code

  if len(pdb_code) == 4:
    filename = f"{pdb_code}.pdb"
    url = f"https://files.rcsb.org/view/{filename}"
  else:
    filename = f"AF-{pdb_code}-F1-model_v3.pdb"
    url = f"https://alphafold.ebi.ac.uk/files/{filename}"
  os.system(f"wget -qnc {url}")
  return filename

def pdb_to_string(pdb_file, chains=None, models=(1,)):
  '''read pdb file and return as string

  Keeps ATOM records, HETATM records of known modified residues (renamed to their
  parent residue and turned into ATOM records) and MODEL/TER/ENDMDL lines for the
  selected chains and models. A trailing insertion code is blanked out, and
  repeated atoms (same model, chain, residue number, residue name and atom name)
  are dropped, keeping the first occurrence.

  Args:
    pdb_file: path to a PDB-format file.
    chains: None for all chains, a chain id, a comma-separated string ("A,B")
      or a list of chain ids.
    models: model number(s) to keep (int, list or tuple); None keeps all models.

  Returns:
    the selected lines joined with newlines.
  '''

  MODRES = {'MSE':'MET','MLY':'LYS','FME':'MET','HYP':'PRO',
            'TPO':'THR','CSO':'CYS','SEP':'SER','M3L':'LYS',
            'HSK':'HIS','SAC':'SER','PCA':'GLU','DAL':'ALA',
            'CME':'CYS','CSD':'CYS','OCS':'CYS','DPR':'PRO',
            'B3K':'LYS','ALY':'LYS','YCM':'CYS','MLZ':'LYS',
            '4BF':'TYR','KCX':'LYS','B3E':'GLU','B3D':'ASP',
            'HZP':'PRO','CSX':'CYS','BAL':'ALA','HIC':'HIS',
            'DBZ':'ALA','DCY':'CYS','DVA':'VAL','NLE':'LEU',
            'SMC':'CYS','AGM':'ARG','B3A':'ALA','DAS':'ASP',
            'DLY':'LYS','DSN':'SER','DTH':'THR','GL3':'GLY',
            'HY3':'PRO','LLP':'LYS','MGN':'GLN','MHS':'HIS',
            'TRQ':'TRP','B3Y':'TYR','PHI':'PHE','PTR':'TYR',
            'TYS':'TYR','IAS':'ASP','GPL':'LYS','KYN':'TRP',
            'SEC':'CYS'}
  restype_1to3 = {'A': 'ALA','R': 'ARG','N': 'ASN',
                  'D': 'ASP','C': 'CYS','Q': 'GLN',
                  'E': 'GLU','G': 'GLY','H': 'HIS',
                  'I': 'ILE','L': 'LEU','K': 'LYS',
                  'M': 'MET','F': 'PHE','P': 'PRO',
                  'S': 'SER','T': 'THR','W': 'TRP',
                  'Y': 'TYR','V': 'VAL'}

  restype_3to1 = {v: k for k, v in restype_1to3.items()}

  if chains is not None:
    if "," in chains: chains = chains.split(",")
    if not isinstance(chains,list): chains = [chains]
  if models is not None:
    # accept a single model number as well as a list/tuple of them
    models = list(models) if isinstance(models, (list, tuple)) else [models]

  modres = {**MODRES}
  lines = []
  seen = set()  # O(1) duplicate check; a list made this quadratic in the atom count
  model = 1
  with open(pdb_file, "rb") as handle:
    for raw in handle:
      line = raw.decode("utf-8","ignore").rstrip()
      if line[:5] == "MODEL":
        model = int(line[5:])
      if models is None or model in models:
        if line[:6] == "MODRES":
          # file-declared modified residue -> standard parent residue
          k = line[12:15]
          v = line[24:27]
          if k not in modres and v in restype_3to1:
            modres[k] = v
        if line[:6] == "HETATM":
          k = line[17:20]
          if k in modres:
            line = "ATOM  "+line[6:17]+modres[k]+line[20:]
        if line[:4] == "ATOM":
          chain = line[21:22]
          if chains is None or chain in chains:
            atom = line[12:12+4].strip()
            resi = line[17:17+3]
            resn = line[22:22+5].strip()
            if resn[-1].isalpha(): # alternative atom
              resn = resn[:-1]
              line = line[:26]+" "+line[27:]
            key = f"{model}_{chain}_{resn}_{resi}_{atom}"
            if key not in seen: # skip alternative placements
              lines.append(line)
              seen.add(key)
        if line[:5] == "MODEL" or line[:3] == "TER" or line[:6] == "ENDMDL":
          lines.append(line)
  return "\n".join(lines)

def from_pdb(pdb_code=None, chains=None, trim_loops=False,
             mask_contacts=False, return_pdb_str=False):
  '''Build an element-level blueprint ({"txt", "adj"}) from a structure.

  Args:
    pdb_code: anything get_pdb() accepts (None/"" to upload, a local path, a
      4-character PDB code, or another id looked up in the AlphaFold DB).
    chains: chain selection passed to pdb_to_string (None = all chains).
    trim_loops: if False, residues without an H/E assignment become coil ("C")
      elements; if True they are left out of the blueprint entirely.
    mask_contacts: if True, detected contacts are reported as "?" instead of "1".
    return_pdb_str: if True, also return the filtered PDB text under "pdb_str".

  Returns:
    dict with "txt" (numpy array of element lengths) and "adj" (num_sse x num_sse
    object array: diagonal = element SSE type, off-diagonal = "0"/"1"/"?" contact
    between two non-coil elements; rows/columns of coil elements stay "0").
  '''

  import pydssp
  def process(secondary_structure, contact_map):
    '''Collapse per-residue SSE labels and a residue contact map to element level.

    An element is a maximal run of identical "H", "E" or "C" labels; any other
    label (e.g. "-") is not part of any element. Two non-coil elements are in
    contact ("1") if any residue pair between them is in contact_map.
    '''
    secondary_structure = np.array(secondary_structure)
    # Find the start and end indices of the continuous secondary structure elements
    sse_start,sse_end = [],[]
    for i, current_element in enumerate(secondary_structure):
      if current_element in ["H", "E", "C"]:
        if i == 0 or secondary_structure[i-1] != current_element:
          sse_start.append(i)
        if i == len(secondary_structure) - 1 or secondary_structure[i+1] != current_element:
          sse_end.append(i)

    sse_types = secondary_structure[sse_start]
    sse_lengths = np.array(sse_end) - np.array(sse_start) + 1
    num_sse = len(sse_lengths)
    # start with "no contact" everywhere, SSE types on the diagonal
    reduced_contact_map = np.full((num_sse, num_sse), '0', dtype=object)
    np.fill_diagonal(reduced_contact_map, sse_types)

    for i in range(num_sse):
      for j in range(num_sse):
        # contacts are only recorded between helix/strand elements, never for coils
        if i != j and sse_types[i] != "C" and sse_types[j] != "C":
          interaction_mask = np.any(contact_map[sse_start[i]:sse_end[i]+1, sse_start[j]:sse_end[j]+1])
          reduced_contact_map[i, j] = str(interaction_mask.astype(int))
          if mask_contacts and reduced_contact_map[i, j] == "1":
            reduced_contact_map[i, j] = "?"


    return {"txt":sse_lengths, "adj":reduced_contact_map}

  def coord_2_cb(coord):
    '''Place a virtual C-beta from the backbone N, CA, C atoms (coord[:, 0..2]).

    Uses fixed coefficients on the CA-N / C-CA vectors and their cross product.
    NOTE(review): assumes coord is (num_res, atoms, 3) as returned by
    pydssp.read_pdbtext with N, CA, C first — confirm against pydssp.
    '''
    N,Ca,C = coord[:,0],coord[:,1],coord[:,2]
    # recreate Cb given N,Ca,C
    b = Ca - N
    c = C - Ca
    a = np.cross(b, c)
    Cb = -0.57910144*a + 0.5689693*b - 0.5441217*c + Ca
    return Cb
  pdb_filename = get_pdb(pdb_code)
  pdb_str = pdb_to_string(pdb_filename, chains=chains)
  coord = pydssp.read_pdbtext(pdb_str)

  # per-residue labels; the code below handles "H", "E" and "-"
  ss = pydssp.assign(coord)

  # filter single length sse
  # (a lone H/E residue whose neighbours differ is relabelled "-"; modified in place)
  for i in range(len(ss)):
    if ss[i] in ["H","E"]:
      if (i == (len(ss)-1) or ss[i] != ss[i+1]) and (i == 0 or ss[i] != ss[i-1]):
        ss[i] = "-"

  if not trim_loops:
    ss = [("C" if s == "-" else s) for s in ss]
  cb = coord_2_cb(coord)
  # residue pairs whose virtual C-beta atoms are closer than 6.0 (PDB coordinate units)
  con = np.sqrt(np.square(cb[:,None] - cb[None,:]).sum(-1)) < 6.0
  out = process(ss, con)
  if return_pdb_str:
    out["pdb_str"] = pdb_str
  return out

def get_adj_ss(adj, txt, buff=0, mask_contacts=False):
  '''Expand an element-level blueprint into per-residue SSE and adjacency arrays.

  Args:
    adj: square blueprint; adj[i][i] is the SSE type of element i
      ("H", "E", "C", "?") and adj[i][j] (i != j) its contact code with
      element j ("0", "1", "?").
    txt: length of each element; elements with length <= 0 are dropped.
    buff: number of buffer residues placed before, between and after the
      kept elements.
    mask_contacts: if True, specified contacts ("1") are emitted as undefined.

  Returns:
    dict with
      "sse": (L,) int array, 0=H 1=E 2=C 3=undefined,
      "adj": (L, L) int array, 0=no contact 1=contact 2=undefined,
    where L = (number of kept elements + 1) * buff + total kept length.
    Buffer residues are undefined in both arrays; residues within one
    defined element are marked "no contact" with each other.
  '''
  sse_code = {"H": 0, "E": 1, "C": 2, "?": 3}
  self_code = {"H": 0, "E": 0, "C": 0, "?": 2}
  contact_code = {"0": 0, "1": 1, "?": 2}

  # select non-zero elements
  idx = [i for i in range(len(adj)) if txt[i] > 0]

  # only kept elements contribute residues; summing all of txt would
  # under-size the arrays (and silently truncate the slices below) if a
  # length were negative
  L = (len(idx) + 1) * buff + sum(txt[i] for i in idx)
  full_adj = np.full((L,L),2)
  full_sse = np.full((L,),3)
  n = buff
  for i in idx:
    full_sse[n:n+txt[i]] = sse_code[adj[i][i]]
    m = buff
    for j in idx:
      k = str(adj[i][j])
      if i == j:
        val = self_code[k]
      else:
        if mask_contacts and k == "1": k = "?"
        val = contact_code[k]
      full_adj[n:n+txt[i],m:m+txt[j]] = val
      m += txt[j] + buff
    n += txt[i] + buff
  return {"adj":full_adj,"sse":full_sse}

class blueprint_gui:
  '''Editable SSE blueprint: a grid of element types/contacts plus lengths.

  State (set by the subclass before use):
    self.elements: number of secondary structure elements.
    self.adj: elements x elements grid of strings; diagonal = SSE type
      ("H", "E", "C", "?"), off-diagonal = contact ("0", "1", "?").
    self.txt: per-element lengths.
    self._CSS / self._JS: stylesheet and script embedded by _create_html().
  '''

  def _toggle_callback(self, row, col):
    '''Cycle the value of grid cell (row, col) after a click.

    Diagonal cells cycle H -> E -> C -> ? -> H and reset that element's length
    to the default for the new type; switching to "?" makes all of its contacts
    "?", switching to C/H resets contacts with defined elements to "0".
    Off-diagonal cells cycle 0 -> 1 -> ? -> 0 (kept symmetric), but only when
    neither element is a coil or undefined.
    '''
    if row == col:
      new_value = {"H":"E","E":"C","C":"?","?":"H"}[self.adj[row][col]]
      self.txt[row] = {"H": 19, "E": 5, "C": 3, "?": 0}[new_value]
      self.adj[row][col] = new_value
      for i in range(self.elements):
        if i != row:
          if new_value == "?":
            self.adj[row][i] = self.adj[i][col] = "?"
          elif self.adj[i][i] != "?" and new_value in ["C","H"]:
            self.adj[row][i] = self.adj[i][col] = '0'
    else:
      if self.adj[row][row] not in ["C","?"] and self.adj[col][col] not in ["C","?"]:
        new_value = {"0":"1","1":"?","?":"0"}[self.adj[row][col]]
        self.adj[row][col] = self.adj[col][row] = new_value

  def _text_callback(self, row, new_value):
    '''Store the length typed into the text box of element `row`.

    Negative lengths are clamped to 0 (a length of 0 drops the element when
    the blueprint is expanded). Input that is not an integer, such as an
    emptied box, leaves the stored length unchanged instead of raising.
    '''
    try:
      value = int(new_value)
    except (TypeError, ValueError):
      return
    self.txt[row] = max(value, 0)

  def _update_callback(self, position, add):
    '''Insert (add=True) or remove (add=False) the element at index `position`.

    A negative position, or one past the last element, means "at the end".
    A new element is a helix of length 19 whose contacts are "0", or "?"
    towards undefined elements.
    '''
    position = int(position)
    # out-of-range positions (the number box can be typed into) act like "end"
    if position < 0 or position > self.elements: position = self.elements
    self.elements = self.elements + 1 if add else self.elements - 1
    if self.elements < 0: self.elements = 0
    adj = [['' for _ in range(self.elements)] for _ in range(self.elements)]
    txt = ['' for _ in range(self.elements)]
    for row in range(self.elements):
      # index of this row in the old grid (shifted around the insert/remove point)
      old_row = row if row < position else row - 1 if add else row + 1
      if add and row == position:
        txt[row] = 19
      else:
        txt[row] = self.txt[old_row]
      for col in range(self.elements):
        old_col = col if col < position else col - 1 if add else col + 1
        if add and (row == position or col == position):
          if row == col:
            adj[row][col] = 'H'
          else:
            cell = self.adj[old_row][old_row] if col == position else self.adj[old_col][old_col]
            adj[row][col] = cell if cell == "?" else '0'
        else:
          adj[row][col] = self.adj[old_row][old_col]
    self.adj = adj
    self.txt = txt

  def _create_html(self):
    '''Render the grid, length boxes and resize controls into self.html_code.'''
    # HTML for initial grid
    html_grid = '<div class="pos"></div>'
    for row in range(self.elements): html_grid += f'<div class="pos">{row}</div>'
    html_grid += '<div class="pos"></div>'
    for row in range(self.elements):
      html_grid += f'<div class="pos">{row}</div>'
      for col in range(self.elements):
        value = self.adj[row][col]
        bgcolor = {"H":"red","E":"yellow","C":"lime","?":"lightgray","0":"white","1":"lightblue"}[value]
        # contacts involving coil/undefined elements cannot be edited: fade them
        if row != col and (self.adj[row][row] in ["?","C"] or self.adj[col][col] in ["?","C"]):
          opacity = 0.1
        else:
          opacity = 1.0
        html_grid += f'<div class="grid-item" id="cell_{row}_{col}" style="background-color:{bgcolor};opacity:{opacity}">{value}</div>'
      html_grid += f'<div><input class="text" type="number" id="cell_{row}" min="0" value="{self.txt[row]}" onchange="textFieldChanged({row}, this)"></div>'

    self.html_code = f"""
    <style>
    {self._CSS}
    .grid-container {{
      display: grid;
      grid-template-columns: repeat({self.elements+2}, 30px);
      gap: 2px;
    }}
    </style>
    <script>{self._JS}</script>
    <label>resize:</label>
    <button id="add" style="width:25px" class="button" onclick="updateGrid(true)">+</button>
    <button id="remove" style="width:25px" class="button" onclick="updateGrid(false)">-</button>
    <input type="number" id="position" min="-1" max="{self.elements}" value="0" class="text">
    <label>(indicate where to +/- an element)</label>
    <div class="grid-container">{html_grid}</div>
    """

class RFdiff_gui(blueprint_gui):
  '''Colab front-end: blueprint editor plus RFdiffusion runner and viewer.

  Extends blueprint_gui (the SSE/contact grid) with a buff_length option,
  writes the expanded ss/adj tensors for scaffold-guided diffusion, runs
  ./RFdiffusion/run_inference.py in the background with live progress, and
  shows / downloads the result with py3Dmol.
  '''

  def __init__(self, elements=5, adj=None, txt=None, buff_length=5, name="test"):
    '''Set up widgets, blueprint state and the Colab callbacks.

    Args:
      elements: number of SSEs of a fresh all-helix blueprint (used only when
        adj or txt is missing).
      adj: square blueprint (diagonal SSE types, off-diagonal contact codes),
        e.g. from from_pdb().
      txt: per-element lengths matching adj.
      buff_length: initial value of the buff_length widget.
      name: prefix for files written under outputs/.
    '''
    self.path = self.name = name
    self.input = widgets.Output()
    self.output = widgets.Output()
    self.buff_length = buff_length
    # whether the last diffuse() run included a binder target (chain B); read by _plot_pdb
    self.use_target = False

    button_style = widgets.Layout(width='84px', height='35px', border='2px solid black')
    self.buttons = {
        "buff_length": widgets.BoundedIntText(description='buff_length', value=self.buff_length, min=0, max=20),
        "reset":       widgets.Button(description='reset',     layout=button_style),
        "animate":     widgets.Button(description='animate',   layout=button_style),
        "freeze":      widgets.Button(description='freeze',    layout=button_style),
        "download":    widgets.Button(description='download',  layout=button_style),
        "color":       widgets.Dropdown(
                        options=['SSE','pLDDT'],
                        value='SSE',
                        description='color',
                        disabled=False)
    }
    self.buttons["animate"].on_click(self._plot_pdb)
    self.buttons["freeze"].on_click(self._plot_pdb)
    self.buttons["download"].on_click(self._download)
    self.buttons["color"].observe(self._plot_pdb)
    self._plot = {"mode":"freeze","color":"SSE"}

    # prep inputs
    if adj is not None and txt is not None:
      self.elements = len(adj)
      self.adj, self.txt = adj,txt
    else:
      self.elements = elements
      self.adj = [["H" if row == col else "0" for col in range(self.elements)] for row in range(self.elements)]
      self.txt = [19 for _ in range(self.elements)]

    # Python handlers presumably invoked from blueprint.js when the grid is edited
    output.register_callback("update_callback", self._update_callback)
    output.register_callback("toggle_callback", self._toggle_callback)
    output.register_callback("text_callback",   self._text_callback)
    with open("blueprint.css","r") as handle:
      self._CSS = handle.read()
    with open("blueprint.js","r") as handle:
      self._JS = handle.read()

  def _redraw(self):
    '''Re-render the blueprint grid and the options into self.input.'''
    with self.input:
      self._create_html()
      self.input.clear_output(wait=True)
      display(
          widgets.VBox([
          widgets.HTML(self.html_code),
          widgets.Label("Options"),
          self.buttons["buff_length"],
        ])
      )

  def display_input(self):
    '''Show the blueprint editor.'''
    self._redraw()
    display(self.input)

  def display_output(self):
    '''Show the (initially empty) output area used by diffuse().'''
    display(self.output)

  def _download(self, button):
    '''Zip this run's outputs (final models and trajectories) and download the zip.'''
    os.system(f"zip -r {self.path}.result.zip outputs/{self.path}* outputs/traj/{self.path}*")
    files.download(f"{self.path}.result.zip")

  def _plot_pdb(self, change):
    '''Render the latest design in self.output.

    Called with an animate/freeze Button (switches between the trajectory
    animation and the final model) or with a change notification from the
    color dropdown (SSE colours vs the B-factor column, labelled pLDDT).
    '''
    update = False
    if isinstance(change, widgets.Button):
      self._plot["mode"] = change.description
      update = True
    elif isinstance(change, dict) and change['name'] == 'value':
      widget = change['owner']
      if isinstance(widget, widgets.Dropdown):
        self._plot["color"] = change["new"]
        update = True
    if not update:
      return

    view = py3Dmol.view()
    if self._plot["mode"] == "animate":
      with open(f"outputs/traj/{self.path}_0_pX0_traj.pdb",'r') as handle:
        view.addModelsAsFrames(handle.read(),'pdb')
    else:
      with open(f"outputs/{self.path}_0.pdb",'r') as handle:
        view.addModel(handle.read(),'pdb')
    if self._plot["color"] == "SSE":
      view.setStyle({"ss":"h","chain":"A"},{'cartoon': {'color':'red'}})
      view.setStyle({"ss":"c","chain":"A"},{'cartoon': {'color':'lime'}})
      view.setStyle({"ss":"s","chain":"A"},{'cartoon': {'color':'yellow'}})
      if self.use_target:
        view.setStyle({"chain":"B"},{'cartoon': {'color':'white'}})
    else:
      view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}})
    view.zoomTo()
    if self._plot["mode"] == "animate":
      view.animate({'loop': 'backAndForth'})
    out = widgets.Output()
    with out: view.show()
    # offer the opposite mode next to the viewer
    toggle = self.buttons["freeze"] if self._plot["mode"] == "animate" else self.buttons["animate"]
    with self.output:
      self.output.clear_output(wait=True)
      display(widgets.VBox([out, widgets.HBox([toggle, self.buttons["download"], self.buttons["color"]])]))

  def _make_path(self):
    '''Pick an output prefix whose final model does not exist yet and create its directory.'''
    os.makedirs(f"outputs/{self.path}", exist_ok=True)
    while os.path.exists(f"outputs/{self.path}_0.pdb"):
      self.path = self.name + "_" + ''.join(random.choices(string.ascii_lowercase + string.digits, k=5))
      os.makedirs(f"outputs/{self.path}", exist_ok=True)

  def _get_adj_ss(self, mask_contacts=False):
    '''Expand the blueprint and save it as tmp_ss.pt / tmp_adj.pt under outputs/<path>/.'''
    full = get_adj_ss(adj=self.adj,
                      txt=self.txt,
                      buff=self.buttons["buff_length"].value,
                      mask_contacts=mask_contacts)
    self._sse = full["sse"]
    self._adj = full["adj"]

    # save results
    loc = [f"outputs/{self.path}/tmp_ss.pt",
           f"outputs/{self.path}/tmp_adj.pt"]
    torch.save(torch.from_numpy(self._sse).float(),loc[0])
    torch.save(torch.from_numpy(self._adj).float(),loc[1])

  def diffuse(self, iterations=50,
             mask_loops=True,
             mask_contacts=False,
             extra_cmd=None,
             use_target=None):
    '''Run scaffold-guided RFdiffusion on the current blueprint and show the result.

    Args:
      iterations: number of diffusion steps (diffuser.T).
      mask_loops: passed as scaffoldguided.mask_loops.
      mask_contacts: if True, specified contacts ("1") are written as undefined.
      extra_cmd: optional list of extra run_inference.py arguments.
      use_target: whether extra_cmd sets up a binder target (drawn as chain B).
        None (default) infers it from "scaffoldguided.target_pdb=True" being
        one of the extra_cmd arguments.
    '''
    # infer instead of reading a notebook-level `use_target` global, which made
    # this method fail with NameError whenever that global was not defined
    if use_target is None:
      use_target = extra_cmd is not None and "scaffoldguided.target_pdb=True" in extra_cmd
    self.use_target = use_target
    self._redraw()
    self._make_path()
    self._get_adj_ss(mask_contacts=mask_contacts)
    # run
    with self.output:
      self.output.clear_output()
      cmd = ["./RFdiffusion/run_inference.py",
             "inference.num_designs=1",
             f"inference.output_prefix=outputs/{self.path}",
             "scaffoldguided.scaffoldguided=True",
             f"scaffoldguided.scaffold_dir=outputs/{self.path}",
             f"diffuser.T={iterations}",
             f"scaffoldguided.mask_loops={mask_loops}",
             "inference.dump_pdb=True",
             "inference.dump_pdb_path=/dev/shm"]

      if extra_cmd is not None:
        cmd += extra_cmd

      self.cmd_str = " ".join(cmd)
      self._run(self.cmd_str, iterations)
    self._plot_pdb(self.buttons["freeze"])

  def _run(self, command, steps, num_designs=1):
    '''Run `command` in the background and display its progress.

    Args:
      command: shell command line, started with nohup via os.system.
      steps: number of intermediate structures expected per design; step n is
        read from /dev/shm/{n}.pdb once its text ends in "TER".
      num_designs: number of designs to wait for.

    The progress bar turns red ("failed") if the process exits before producing
    all steps; interrupting the cell sends SIGTERM to it ("stopped").
    '''
    import signal  # used for SIGTERM when the cell is interrupted

    def run_command_and_get_pid(command):
      pid_file = '/dev/shm/pid'
      os.system(f'nohup {command} & echo $! > {pid_file}')
      with open(pid_file, 'r') as f:
        pid = int(f.read().strip())
      os.remove(pid_file)
      return pid

    def is_process_running(pid):
      try:
        os.kill(pid, 0)  # signal 0: existence check only
      except OSError:
        return False
      else:
        return True

    run_output = widgets.Output()
    progress = widgets.FloatProgress(min=0, max=1, description='running', bar_style='info')
    display(widgets.VBox([progress, run_output]))

    # clear previous run
    for n in range(steps):
      if os.path.isfile(f"/dev/shm/{n}.pdb"):
        os.remove(f"/dev/shm/{n}.pdb")

    pid = run_command_and_get_pid(command)
    try:
      fail = False
      for _ in range(num_designs):
        # for each step
        for n in range(steps):
          step_pdb = f"/dev/shm/{n}.pdb"
          wait = True
          while wait and not fail:
            time.sleep(0.5)
            # check if output generated (complete once the text ends in "TER")
            if os.path.isfile(step_pdb):
              with open(step_pdb) as handle:
                pdb_str = handle.read()
              if pdb_str[-3:] == "TER":
                wait = False
              elif not is_process_running(pid):
                fail = True
            elif not is_process_running(pid):
              fail = True

          if fail:
            progress.bar_style = 'danger'
            progress.description = "failed"
            break
          else:
            progress.value = (n+1) / steps
            with run_output:
              run_output.clear_output(wait=True)
              view = py3Dmol.view(js='https://3dmol.org/build/3Dmol.js')
              view.addModel(pdb_str,'pdb')
              view.setStyle({'cartoon': {'colorscheme': {'prop':'b','gradient': 'roygb','min':0.5,'max':0.9}}})
              view.zoomTo()
              view.show()
          if os.path.exists(step_pdb):
            os.remove(step_pdb)

        if fail:
          progress.bar_style = 'danger'
          progress.description = "failed"
          break

      while is_process_running(pid):
        time.sleep(0.5)

    except KeyboardInterrupt:
      try:
        os.kill(pid, signal.SIGTERM)
      except ProcessLookupError:
        pass  # the process had already exited
      progress.bar_style = 'danger'
      progress.description = "stopped"

# build the blueprint GUI from scratch (manual) or from a PDB structure (automated)
if blueprint_mode != "automated":
  rfdiff = RFdiff_gui(elements, name=name)
else:
  pdb_feats = from_pdb(pdb, chains=chain, trim_loops=trim_loops)
  # with trimmed loops, SSEs are separated by 5 buffer residues instead of native loops
  rfdiff = RFdiff_gui(**pdb_feats, name=name, buff_length=5 if trim_loops else 0)
rfdiff.display_input()

In [None]:
%%time
#@title run **RFdiffusion**
iterations = 25 #@param ["25", "50", "100", "200"] {type:"raw"}
mask_loops = True #@param {type:"boolean"}
mask_contacts = False #@param {type:"boolean"}
#@markdown **Optional**: specify target info (for binder design)
use_target = False #@param {type:"boolean"}
target_pdb = "" #@param {type:"string"}
target_chain = "A" #@param {type:"string"}
target_hotspot = "" #@param {type:"string"}
# an empty chain field selects all chains (same convention as `chain` in the
# blueprint cell); otherwise it would become [""] and match no chain at all
if target_chain == "": target_chain = None

if "rfdiff" not in dir():
  print("Error, looks like you didn't run the cell above")
else:
  extra_cmd = None
  if use_target:
    # prep target features
    rfdiff._make_path()
    path = f"outputs/{rfdiff.path}/target"
    os.makedirs(path, exist_ok=True)
    target = from_pdb(target_pdb, target_chain, return_pdb_str=True)
    target_pdb = f"{path}/input.pdb"
    with open(target_pdb,"w") as handle:
      handle.write(target["pdb_str"])
    # per-residue ss / adjacency tensors for the target (no buffer residues)
    full = get_adj_ss(adj=target["adj"], txt=target["txt"])
    torch.save(torch.from_numpy(full["sse"]).float(),f"{path}/ss.pt")
    torch.save(torch.from_numpy(full["adj"]).float(),f"{path}/adj.pt")

    extra_cmd = ["scaffoldguided.target_pdb=True",
                 f"scaffoldguided.target_path={path}/input.pdb",
                 f"scaffoldguided.target_ss={path}/ss.pt",
                 f"scaffoldguided.target_adj={path}/adj.pt",
                 "denoiser.noise_scale_ca=0",
                 "denoiser.noise_scale_frame=0"]
    if target_hotspot != "":
      extra_cmd += [f"'ppi.hotspot_res=[{target_hotspot}]'"]

  rfdiff.display_output()
  rfdiff.diffuse(iterations,
                 mask_loops=mask_loops,
                 mask_contacts=mask_contacts,
                 extra_cmd=extra_cmd)