Commit 38b5cee4 authored by hgarrereyn

work dump, add onnx stuff

parent 1c333a36
# Lead Optimization
# Overview
# Structure
- `config`: configuration information (e.g. TRAIN/TEST partitions)
- `data`: training/inference data (see [`data/README.md`](data/README.md))
......
import argparse
import sys
import os
import subprocess
import tempfile
RUN_DIR = '/zfs1/jdurrant/durrantlab/hag63/leadopt_pytorch'
def main():
    parser = argparse.ArgumentParser()
    parser.add_argument('-t', '--time', default='10:00:00')
    parser.add_argument('-p', '--partition', default='gtx1080')
    parser.add_argument('-m', '--mem', default='16g')
    parser.add_argument('path')
    parser.add_argument('script')
    args = parser.parse_args()

    run_path = os.path.join(RUN_DIR, args.path)

    if os.path.exists(run_path):
        print('[!] Run exists at %s' % run_path)
        overwrite = input('- Overwrite? [Y/n]: ')
        if overwrite.lower() == 'n':
            print('Exiting...')
            sys.exit(0)
    else:
        print('[*] Creating run directory %s' % run_path)
        os.mkdir(run_path)

    # Build a SLURM batch script that runs {script} inside the run directory.
    script = '''#!/bin/bash
#SBATCH --job-name={name}
#SBATCH --output={run_path}/slurm_out.txt
#SBATCH --error={run_path}/slurm_err.txt
#SBATCH --time={time}
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cluster=gpu
#SBATCH --partition={partition}
#SBATCH --mail-user=hag63@pitt.edu
#SBATCH --mail-type=END,FAIL
#SBATCH --gres=gpu:1
#SBATCH --mem={mem}

cd /ihome/jdurrant/hag63/cbio/leadopt_pytorch/
./setup.sh

export PYTHONPATH=/ihome/jdurrant/hag63/cbio/leadopt_pytorch/
# export WANDB_DIR=/ihome/jdurrant/hag63/wandb_abs
export WANDB_DISABLE_CODE=true

cd {run_path}
python /ihome/jdurrant/hag63/cbio/leadopt_pytorch/{script}
'''.format(
        name='leadopt_%s' % args.path,
        run_path=run_path,
        time=args.time,
        partition=args.partition,
        mem=args.mem,
        script=args.script
    )

    print('[*] Running script...')
    # Write the batch script to a temporary file and submit it with sbatch.
    with tempfile.NamedTemporaryFile('w') as f:
        f.write(script)
        f.flush()
        r = subprocess.run('sbatch %s' % f.name, shell=True,
                           stdin=subprocess.PIPE, stdout=subprocess.PIPE)
        print(r)


if __name__ == '__main__':
    main()
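For reference, a typical invocation of the launcher above might look like `python run_job.py -t 4:00:00 -p gtx1080 my_run train.py` (the filename `run_job.py` is a stand-in; this diff does not show the script's actual name). It creates `RUN_DIR/my_run` and submits a single-GPU job whose SLURM logs land in that directory.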
#!/bin/bash

DEFAULT_RUN_PATH=/zfs1/jdurrant/durrantlab/hag63/leadopt_pytorch/

if [[ $# -ne 4 ]]; then
    echo "Usage: $0 <run_name> <gpu_partition> <unused> <train_args>"
    exit 1
fi

# Absolute path to train.py, resolved before changing directories.
ABS_SCRIPT=$(pwd)/train.py

# navigate to runs directory
RUNS_DIR="${RUNS_DIR:-$DEFAULT_RUN_PATH}"
cd "$RUNS_DIR"

if [[ -d $1 ]]; then
    echo "Warning: run directory $1 already exists!"
    exit 1
fi

echo "Creating run directory ($1)..."
mkdir "$1"

echo "Running script..."
sbatch <<EOT
#!/bin/bash
#SBATCH --job-name=$1
#SBATCH --output=$RUNS_DIR/$1/slurm_out.txt
#SBATCH --error=$RUNS_DIR/$1/slurm_err.txt
#SBATCH --time=10:00:00
#SBATCH --nodes=1
#SBATCH --ntasks=1
#SBATCH --cluster=gpu
#SBATCH --partition=$2
#SBATCH --mail-user=hag63@pitt.edu
#SBATCH --mail-type=END,FAIL
#SBATCH --gres=gpu:1
#SBATCH --mem=80g

cd /ihome/jdurrant/hag63/cbio/leadopt_pytorch/
source ./setup.sh
export PYTHONPATH=\$PYTHONPATH:/ihome/jdurrant/hag63/cbio/leadopt_pytorch/

cd $RUNS_DIR/$1
python $ABS_SCRIPT $4
EOT
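As with the Python launcher, a typical invocation might be `./submit_train.sh my_run gtx1080 x "--epochs 50"` (the script name, the placeholder third argument, and the training flags are all illustrative; `$3` is required by the argument-count check but never used).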
@@ -13,52 +13,9 @@ import h5py
from leadopt.models.voxel import VoxelFingerprintNet2
from leadopt.infer import infer_all
from leadopt.pretrained import MODELS
class SavedModel(object):

    model_class = None
    model_args = None

    @classmethod
    def load(cls, path):
        # Instantiate the model, load trained weights, and switch to eval mode.
        m = cls.model_class(**cls.model_args).cuda()
        m.load_state_dict(torch.load(path))
        m.eval()
        return m

    @classmethod
    def get_fingerprints(cls, path):
        # Load the reference fingerprint matrix and the matching SMILES strings.
        f = h5py.File(os.path.join(path, cls.fingerprint_data), 'r')
        data = f['fingerprints'][()]
        smiles = f['smiles'][()]
        f.close()
        return data, smiles


class V2_RDK_M150(SavedModel):
    model_class = VoxelFingerprintNet2
    model_args = {
        'in_channels': 18,
        'output_size': 2048,
        'batchnorm': True,
        'sigmoid': True
    }
    grid_width = 24
    grid_res = 1
    receptor_types = [6,7,8,9,15,16,17,35,53]
    parent_types = [6,7,8,9,15,16,17,35,53]
    fingerprint_data = 'fingerprint_rdk_2048.h5'


MODELS = {
    'rdk_m150': V2_RDK_M150
}
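A minimal usage sketch for the saved-model registry above (the checkpoint and fingerprint paths are placeholders, not files shipped in this diff):

model = MODELS['rdk_m150'].load('rdk_m150.pt')              # VoxelFingerprintNet2 on GPU, eval mode
fps, smiles = MODELS['rdk_m150'].get_fingerprints('data/')  # reads data/fingerprint_rdk_2048.h5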
def main():
    parser = argparse.ArgumentParser()
......
@@ -8,28 +8,161 @@ import os
from torch.utils.data import DataLoader, Dataset
import numpy as np
import h5py
import tqdm
# Default atomic numbers to use as grid layers
DEFAULT_TYPES = [6,7,8,9,15,16,17,35,53]
def remap_atoms(atom_types):
    '''
    Returns a function that maps an atomic number to a layer index.

    Params:
    - atom_types: which atom types to use as layers
    '''
    atom_mapping = {atom_types[i]: i for i in range(len(atom_types))}

    def f(x):
        if x in atom_mapping:
            return atom_mapping[x]
        else:
            return -1

    return f


def rec_typer_single(t):
    '''single channel, no hydrogen'''
    # unpack features
    num, hacc, hdon, aro, _ = t
    if num != 1:
        return 0
    else:
        return -1


def rec_typer_single_h(t):
    '''single channel, with hydrogen'''
    return 0


def rec_typer_basic(t):
    '''simple types'''
    BASIC_TYPES = [6,7,8,16]
    num, hacc, hdon, aro, _ = t
    if num in BASIC_TYPES:
        return BASIC_TYPES.index(num)
    else:
        return -1


def rec_typer_basic_h(t):
    '''simple types, with hydrogen'''
    BASIC_TYPES = [1,6,7,8,16]
    num, hacc, hdon, aro, _ = t
    if num in BASIC_TYPES:
        return BASIC_TYPES.index(num)
    else:
        return -1


# Entries are matched as (num, hacc, hdon, aro), the same order used when
# unpacking in the descriptive typer functions below.
REC_DESC = [
    (6,0,0,0),  # C
    (6,1,0,0),  # C-aro
    (7,0,0,0),  # N
    (7,0,0,1),  # N-hacc
    (7,0,1,0),  # N-hdon
    (7,1,0,0),  # N-aro
    (7,0,1,1),  # N-hdon-hacc
    (7,1,0,1),  # N-aro-hacc
    (7,1,1,0),  # N-aro-hdon
    (8,0,0,0),  # O
    (8,0,0,1),  # O-hacc
    (8,0,1,1),  # O-hdon-hacc
    (16,0,0,0), # S
]

REC_DESC_H = [
    (1,0,0,0),  # H
    (6,0,0,0),  # C
    (6,1,0,0),  # C-hacc
    (7,0,0,0),  # N
    (7,0,0,1),  # N-aro
    (7,0,1,0),  # N-hdon
    (7,1,0,0),  # N-hacc
    (7,0,1,1),  # N-hdon-aro
    (7,1,0,1),  # N-hacc-aro
    (7,1,1,0),  # N-hacc-hdon
    (8,0,0,0),  # O
    (8,0,0,1),  # O-aro
    (8,0,1,1),  # O-hdon-aro
    (16,0,0,0), # S
]


def rec_typer_desc(t):
    '''descriptive types'''
    num, hacc, hdon, aro, _ = t
    f = (num, hacc, hdon, aro)
    if f in REC_DESC:
        return REC_DESC.index(f)
    else:
        return -1


def rec_typer_desc_h(t):
    '''descriptive types, with hydrogen'''
    num, hacc, hdon, aro, _ = t
    f = (num, hacc, hdon, aro)
    if f in REC_DESC_H:
        return REC_DESC_H.index(f)
    else:
        return -1


def lig_typer_single(t):
    if t != 1:
        return 0
    else:
        return -1


def lig_typer_single_h(t):
    return 0


def lig_typer_simple(t):
    BASIC_TYPES = [6,7,8]
    if t in BASIC_TYPES:
        return BASIC_TYPES.index(t)
    else:
        return -1


def lig_typer_simple_h(t):
    BASIC_TYPES = [1,6,7,8]
    if t in BASIC_TYPES:
        return BASIC_TYPES.index(t)
    else:
        return -1


def lig_typer_desc(t):
    DESC_TYPES = [6,7,8,16,9,15,17,35,5,53]
    if t in DESC_TYPES:
        return DESC_TYPES.index(t)
    else:
        return -1


def lig_typer_desc_h(t):
    DESC_TYPES = [1,6,7,8,16,9,15,17,35,5,53]
    if t in DESC_TYPES:
        return DESC_TYPES.index(t)
    else:
        return -1


REC_TYPER = {
    'single': rec_typer_single,
    'single_h': rec_typer_single_h,
    'simple': rec_typer_basic,
    'simple_h': rec_typer_basic_h,
    'desc': rec_typer_desc,
    'desc_h': rec_typer_desc_h
}

LIG_TYPER = {
    'single': lig_typer_single,
    'single_h': lig_typer_single_h,
    'simple': lig_typer_simple,
    'simple_h': lig_typer_simple_h,
    'desc': lig_typer_desc,
    'desc_h': lig_typer_desc_h
}
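A quick sanity check of the typer interface (receptor typers take the 5-tuple (num, hacc, hdon, aro, extra) while ligand typers take a bare atomic number; the values below are illustrative):

assert REC_TYPER['simple']((6, 0, 0, 0, 0)) == 0    # carbon -> channel 0
assert REC_TYPER['simple']((1, 0, 0, 0, 0)) == -1   # hydrogen -> invisible
assert LIG_TYPER['simple'](8) == 2                  # oxygen -> channel 2
assert LIG_TYPER['simple'](15) == -1                # phosphorus -> invisible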
class FragmentDataset(Dataset):
@@ -37,15 +170,17 @@ class FragmentDataset(Dataset):
Utility class to work with the packed fragments.h5 format
'''
def __init__(self, fragment_file, fingerprint_file, filter_rec=None, atom_types=DEFAULT_TYPES, fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None):
def __init__(self, fragment_file, fingerprint_file, rec_typer, lig_typer, filter_rec=None,
fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None, verbose=False):
'''
Initialize the fragment dataset
Params:
- fragment_file: path to fragments.h5
- fingerprint_file: path to fingerprints.h5
- rec_typer: function to map receptor rows to layer index
- lig_typer: function to map ligand rows to layer index
- filter_rec: list of receptor ids to use (or None to use all)
- atom_types: which atom types to use as layers
Filtering options:
- fdist_min: minimum fragment distance
@@ -53,9 +188,11 @@
- fmass_min: minimum fragment mass (Da)
- fmass_max: maximum fragment mass (Da)
'''
self.verbose = verbose
# load receptor/fragment information
self.rec = self.load_rec(fragment_file, atom_types)
self.frag = self.load_fragments(fragment_file, atom_types)
self.rec = self.load_rec(fragment_file, rec_typer)
self.frag = self.load_fragments(fragment_file, lig_typer)
# load fingerprint information
self.fingerprints = self.load_fingerprints(fingerprint_file)
@@ -91,19 +228,22 @@
self.valid_fingerprints = self.compute_valid_fingerprints()
def load_rec(self, fragment_file, atom_types):
def load_rec(self, fragment_file, rec_typer):
'''Load receptor information'''
f = h5py.File(fragment_file, 'r')
rec_data = f['rec_data'][()]
rec_coords = f['rec_coords'][()]
rec_types = f['rec_types'][()]
rec_lookup = f['rec_lookup'][()]
# unpack rec data into separate structures
rec_coords = rec_data[:,:3].astype(np.float32)
rec_types = rec_data[:,3].reshape(-1,1).astype(np.int32)
rec_remapped = np.vectorize(remap_atoms(atom_types))(rec_types)
r = range(len(rec_types))
if self.verbose:
r = tqdm.tqdm(r, desc='Remap receptor atoms')
rec_remapped = np.zeros(len(rec_types)).astype(np.int32)
for i in r:
rec_remapped[i] = rec_typer(rec_types[i])
# create rec mapping
rec_mapping = {}
for i in range(len(rec_lookup)):
@@ -121,7 +261,7 @@ class FragmentDataset(Dataset):
return rec
def load_fragments(self, fragment_file, atom_types):
def load_fragments(self, fragment_file, lig_typer):
'''Load fragment information'''
f = h5py.File(fragment_file, 'r')
@@ -134,13 +274,17 @@
# unpack frag data into separate structures
frag_coords = frag_data[:,:3].astype(np.float32)
frag_types = frag_data[:,3].reshape(-1,1).astype(np.int32)
frag_types = frag_data[:,3].astype(np.int32)
frag_remapped = np.vectorize(remap_atoms(atom_types))(frag_types)
frag_remapped = np.vectorize(lig_typer)(frag_types)
# find and save connection point
r = range(len(frag_lookup))
if self.verbose:
r = tqdm.tqdm(r, desc='Frag connection point')
frag_conn = np.zeros((len(frag_lookup), 3))
for i in range(len(frag_lookup)):
for i in r:
_,f_start,f_end,_,_ = frag_lookup[i]
fdat = frag_data[f_start:f_end]
......
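A hedged sketch of constructing the dataset against the new typer-based API (the paths and filter values below are placeholders, not repo defaults):

dataset = FragmentDataset(
    'fragments.h5', 'fingerprints.h5',
    rec_typer=REC_TYPER['simple'],
    lig_typer=LIG_TYPER['simple'],
    fdist_max=4.0,   # illustrative distance cutoff
    fmass_max=150,   # illustrative mass cutoff (Da)
    verbose=True)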
@@ -67,7 +67,8 @@ def gpu_gridify(grid, width, res, center, rot, atom_num, atom_coords, atom_types
while i < atom_num:
# fetch atom
fx, fy, fz = atom_coords[i]
ft = atom_types[i][0]
# ft = atom_types[i][0]
ft = atom_types[i]
i += 1
# invisible atoms
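The switch from `atom_types[i][0]` to `atom_types[i]` tracks the dataset change above: per-atom types now arrive as a flat 1-D array of typer-assigned layer indices rather than an (N, 1) column of raw atomic numbers.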
@@ -152,13 +153,15 @@ def mol_gridify(
)
def get_batch(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False, include_freq=False):
def get_batch(data, rec_channels, parent_channels, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False, include_freq=False):
assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"
dim = 18
if ignore_receptor or ignore_parent:
dim = 9
dim = 0
if not ignore_receptor:
dim += rec_channels
if not ignore_parent:
dim += parent_channels
# create a tensor with shared memory on the gpu
t, grid = make_tensor((batch_size, dim, width, width, width))
@@ -182,7 +185,7 @@ def get_batch(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_rec
mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
else:
mol_gridify(grid, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=9)
mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=parent_channels)
fingerprints[i] = fp
freq[i] = extra['freq']
@@ -196,63 +199,63 @@ def get_batch(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_rec
return t, t_fingerprints, batch_set
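To make the new channel bookkeeping concrete, here is a small sketch (not from the repo) that mirrors the `dim` logic in `get_batch` above:

def grid_channels(rec_channels, parent_channels,
                  ignore_receptor=False, ignore_parent=False):
    # Each retained input contributes its own block of grid layers.
    dim = 0
    if not ignore_receptor:
        dim += rec_channels
    if not ignore_parent:
        dim += parent_channels
    return dim

assert grid_channels(9, 9) == 18                     # receptor + parent
assert grid_channels(9, 9, ignore_parent=True) == 9  # receptor only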
# def get_batch_dual(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False):
#     # get batch
#     t, fp, batch_set = get_batch(data, batch_set, batch_size, width, res, ignore_receptor, ignore_parent)
#
#     f = data.fingerprints['fingerprint_data']
#
#     # corrupt fingerprints
#     false_fp = torch.clone(fp)
#     for i in range(batch_size):
#         # idx = np.random.randint(fp.shape[1])
#         # false_fp[i,idx] = (1 - false_fp[i,idx]) # flip
#         idx = np.random.randint(f.shape[0])
#         false_fp[i] = torch.Tensor(f[idx]) # replace
#
#     comb_t = torch.cat([t,t], axis=0)
#     comb_fp = torch.cat([fp, false_fp], axis=0)
#
#     y = torch.zeros((batch_size * 2,1)).cuda()
#     y[:batch_size] = 1
#
#     return (comb_t, comb_fp, y, batch_set)


# def get_batch_full(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False):
#     assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"
#
#     dim = 18
#     if ignore_receptor or ignore_parent:
#         dim = 9
#
#     # create a tensor with shared memory on the gpu
#     t_context, grid_context = make_tensor((batch_size, dim, width, width, width))
#     t_frag, grid_frag = make_tensor((batch_size, 9, width, width, width))
#
#     if batch_set is None:
#         batch_set = np.random.choice(len(data), size=batch_size, replace=False)
#
#     for i in range(len(batch_set)):
#         idx = batch_set[i]
#         f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fp = data[idx]
#
#         # random rotation
#         rot = rand_rot()
#
#         if ignore_receptor:
#             mol_gridify(grid_context, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
#         elif ignore_parent:
#             mol_gridify(grid_context, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
#         else:
#             mol_gridify(grid_context, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
#             mol_gridify(grid_context, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=9)
#
#         mol_gridify(grid_frag, f_coords, f_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
#
#     return t_context, t_frag, batch_set
def get_raw_batch(r_coords, r_types, p_coords, p_types, conn, num_samples=32, width=24, res=1, r_dim=9, p_dim=9):
......
@@ -12,7 +12,7 @@ import h5py
import tqdm
from leadopt.grid_util import get_raw_batch