Commit 78833f06 authored by Harrison Green

init

**/*.pyc
**/__pycache__
.ipynb_checkpoints
data/**
!data/README.md
pretrained/**
!pretrained/README.md
.DS_Store
# Lead Optimization via Fragment Prediction
# Overview
- `config`: configuration information (e.g. TRAIN/TEST partitions)
- `data`: training/inference data (see [`data/README.md`](data/README.md))
- `docker`: Docker environment
- `leadopt`: main module code
  - `models`: architecture definitions
  - `data_util.py`: utility wrapper code around fragment and fingerprint datasets
  - `grid_util.py`: GPU-accelerated grid generation code
  - `infer.py`: code for inference with a trained model
  - `metrics.py`
  - `train.py`: training loops
  - `util.py`: extra utility code (mostly rdkit)
- `pretrained`: pretrained models (see [`pretrained/README.md`](pretrained/README.md))
- `scripts`: data processing scripts (see [`scripts/README.md`](scripts/README.md))
- `train.py`: CLI interface to launch training runs
- `leadopt.py`: CLI interface to run inference on new samples
# Training
You can train models with the `train.py` utility script.
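The available flags depend on the script's argparse configuration; you can list them with:

```
python train.py --help
```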
This folder contains data used during training and inference.
You can download the data here: ...
'''
fragment prediction CLI tool
'''
import argparse
import os
import torch
import torch.nn as nn
import torch.nn.functional as F
import h5py
from leadopt.models.voxel import VoxelFingerprintNet2
from leadopt.infer import infer_all
class SavedModel(object):
model_class = None
model_args = None
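# concrete subclasses override model_class/model_args and set
# fingerprint_data (the fingerprint .h5 filename used in get_fingerprints)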
@classmethod
def load(cls, path):
m = cls.model_class(**cls.model_args).cuda()
m.load_state_dict(torch.load(path))
m.eval()
return m
@classmethod
def get_fingerprints(cls, path):
f = h5py.File(os.path.join(path, cls.fingerprint_data), 'r')
data = f['fingerprints'][()]
smiles = f['smiles'][()]
f.close()
return data, smiles
class V2_RDK_M150(SavedModel):
model_class = VoxelFingerprintNet2
model_args = {
'in_channels': 18,
'output_size': 2048,
'batchnorm': True,
'sigmoid': True
}
grid_width=24
grid_res=1
receptor_types=[6,7,8,9,15,16,17,35,53]
parent_types=[6,7,8,9,15,16,17,35,53]
fingerprint_data = 'fingerprint_rdk_2048.h5'
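# registry of available pretrained model configurations, keyed by CLI name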
MODELS = {
'rdk_m150': V2_RDK_M150
}
def main():
parser = argparse.ArgumentParser()
parser.add_argument('-r', '--receptor', required=True, help='Receptor file (.pdb)')
parser.add_argument('-l', '--ligand', required=True, help='Ligand file (.sdf)')
parser.add_argument('-n', '--num_samples', type=int, default=16, help='Number of random rotation samples to use per prediction (default: 16)')
parser.add_argument('-k', '--num_suggestions', type=int, default=25, help='Number of suggestions per fragment')
parser.add_argument('-m', '--model', default=list(MODELS)[0], choices=list(MODELS))
parser.add_argument('-mp', '--model_path', required=True)
parser.add_argument('-d', '--data_path', required=True)
args = parser.parse_args()
# load model
m = MODELS[args.model].load(args.model_path)
fingerprints, smiles = MODELS[args.model].get_fingerprints(args.data_path)
# run infer step
res = infer_all(
model=m,
fingerprints=fingerprints,
smiles=smiles,
rec_path=args.receptor,
lig_path=args.ligand,
num_samples=args.num_samples,
k=args.num_suggestions
)
print(res)
if __name__ == '__main__':
main()
# example usage (model and data paths are placeholders; -mp and -d are required):
# python leadopt.py -r my_receptor.pdb -l my_ligand.sdf -m rdk_m150 -mp <model_path> -d <data_path>
'''
data_util.py
contains utility code for reading packed training data
'''
import os
from torch.utils.data import DataLoader, Dataset
import numpy as np
import h5py
# Default atomic numbers to use as grid layers
DEFAULT_TYPES = [6,7,8,9,15,16,17,35,53]
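# i.e. C, N, O, F, P, S, Cl, Br, I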
def remap_atoms(atom_types):
'''
Returns a function that maps an atomic number to layer index.
Params:
- atom_types: which atom types to use as layers
'''
atom_mapping = {atom_types[i]:i for i in range(len(atom_types))}
def f(x):
if x in atom_mapping:
return atom_mapping[x]
else:
return -1
return f
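# e.g. remap_atoms([6,7,8]) maps C->0, N->1, O->2 and any other atomic
# number to -1; atoms mapped to -1 are skipped during grid generation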
class FragmentDataset(Dataset):
'''
Utility class to work with the packed fragments.h5 format
'''
def __init__(self, fragment_file, fingerprint_file, filter_rec=None, atom_types=DEFAULT_TYPES, fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None):
'''
Initialize the fragment dataset
Params:
- fragment_file: path to fragments.h5
- fingerprint_file: path to fingerprints.h5
- filter_rec: list of receptor ids to use (or None to use all)
- atom_types: which atom types to use as layers
Filtering options:
- fdist_min: minimum fragment distance
- fdist_max: maximum fragment distance
- fmass_min: minimum fragment mass (Da)
- fmass_max: maximum fragment mass (Da)
'''
# load receptor/fragment information
self.rec = self.load_rec(fragment_file, atom_types)
self.frag = self.load_fragments(fragment_file, atom_types)
# load fingerprint information
self.fingerprints = self.load_fingerprints(fingerprint_file)
# keep track of valid examples
valid_mask = np.ones(self.frag['frag_lookup'].shape[0], dtype=bool)
# filter by receptor id
if filter_rec is not None:
valid_rec = np.vectorize(lambda k: k[0].decode('ascii') in filter_rec)(self.frag['frag_lookup'])
valid_mask *= valid_rec
# filter by fragment distance
if fdist_min is not None:
valid_mask[self.frag['frag_dist'] < fdist_min] = 0
if fdist_max is not None:
valid_mask[self.frag['frag_dist'] > fdist_max] = 0
# filter by fragment mass
if fmass_min is not None:
valid_mask[self.frag['frag_mass'] < fmass_min] = 0
if fmass_max is not None:
valid_mask[self.frag['frag_mass'] > fmass_max] = 0
# convert to a list of indexes
self.valid_idx = np.where(valid_mask)[0]
# compute frequency metrics over valid fragments
# (valid idx -> (smiles count))
self.freq = self.compute_freq()
self.valid_fingerprints = self.compute_valid_fingerprints()
def load_rec(self, fragment_file, atom_types):
'''Load receptor information'''
f = h5py.File(fragment_file, 'r')
rec_data = f['rec_data'][()]
rec_lookup = f['rec_lookup'][()]
# unpack rec data into separate structures
rec_coords = rec_data[:,:3].astype(np.float32)
rec_types = rec_data[:,3].reshape(-1,1).astype(np.int32)
rec_remapped = np.vectorize(remap_atoms(atom_types))(rec_types)
# create rec mapping
rec_mapping = {}
for i in range(len(rec_lookup)):
rec_mapping[rec_lookup[i][0].decode('ascii')] = i
rec = {
'rec_coords': rec_coords,
'rec_types': rec_types,
'rec_remapped': rec_remapped,
'rec_lookup': rec_lookup,
'rec_mapping': rec_mapping
}
f.close()
return rec
def load_fragments(self, fragment_file, atom_types):
'''Load fragment information'''
f = h5py.File(fragment_file, 'r')
frag_data = f['frag_data'][()]
frag_lookup = f['frag_lookup'][()]
frag_smiles = f['frag_smiles'][()]
frag_mass = f['frag_mass'][()]
frag_dist = f['frag_dist'][()]
# unpack frag data into separate structures
frag_coords = frag_data[:,:3].astype(np.float32)
frag_types = frag_data[:,3].reshape(-1,1).astype(np.int32)
frag_remapped = np.vectorize(remap_atoms(atom_types))(frag_types)
# find and save connection point
frag_conn = np.zeros((len(frag_lookup), 3))
for i in range(len(frag_lookup)):
_,f_start,f_end,_,_ = frag_lookup[i]
fdat = frag_data[f_start:f_end]
found = False
for j in range(len(fdat)):
if fdat[j][3] == 0:
frag_conn[i,:] = tuple(fdat[j])[:3]
found = True
break
assert found, "missing fragment connection point at %d" % i
frag = {
'frag_coords': frag_coords, # d_idx -> (x,y,z)
'frag_types': frag_types, # d_idx -> (type)
'frag_remapped': frag_remapped, # d_idx -> (layer)
'frag_lookup': frag_lookup, # f_idx -> (rec_id, fstart, fend, pstart, pend)
'frag_conn': frag_conn, # f_idx -> (x,y,z)
'frag_smiles': frag_smiles, # f_idx -> smiles
'frag_mass': frag_mass, # f_idx -> mass
'frag_dist': frag_dist, # f_idx -> dist
}
f.close()
return frag
def load_fingerprints(self, fingerprint_file):
'''load fingerprint information'''
f = h5py.File(fingerprint_file, 'r')
fingerprint_data = f['fingerprints'][()]
fingerprint_smiles = f['smiles'][()]
# create smiles->idx mapping
fingerprint_mapping = {}
for i in range(len(fingerprint_smiles)):
sm = fingerprint_smiles[i].decode('ascii')
fingerprint_mapping[sm] = i
fingerprints = {
'fingerprint_data': fingerprint_data,
'fingerprint_mapping': fingerprint_mapping,
'fingerprint_smiles': fingerprint_smiles,
}
f.close()
return fingerprints
def compute_freq(self):
'''compute fragment frequencies'''
all_smiles = self.frag['frag_smiles']
valid_smiles = all_smiles[self.valid_idx]
smiles_freq = {}
for i in range(len(valid_smiles)):
sm = valid_smiles[i].decode('ascii')
if sm not in smiles_freq:
smiles_freq[sm] = 0
smiles_freq[sm] += 1
freq = np.zeros(len(valid_smiles))
for i in range(len(valid_smiles)):
freq[i] = smiles_freq[valid_smiles[i].decode('ascii')]
return freq
def compute_valid_fingerprints(self):
'''compute a list of valid fingerprint indexes'''
valid_sm = self.frag['frag_smiles'][self.valid_idx]
valid_sm = list(set(valid_sm)) # unique
valid_idx = []
for sm in valid_sm:
valid_idx.append(self.fingerprints['fingerprint_mapping'][sm.decode('ascii')])
valid_idx = sorted(valid_idx)
return valid_idx
def normalize_fingerprints(self, std, mean):
'''normalize fingerprints with a given std and mean'''
self.fingerprints['fingerprint_data'] -= mean
self.fingerprints['fingerprint_data'] /= std
def __len__(self):
'''returns the number of fragment examples'''
return self.valid_idx.shape[0]
def __getitem__(self, idx):
'''
retrieve the nth example
returns (f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fingerprint, extra)
'''
# convert to fragment index
frag_idx = self.valid_idx[idx]
# lookup fragment
rec_id, f_start, f_end, p_start, p_end = self.frag['frag_lookup'][frag_idx]
smiles = self.frag['frag_smiles'][frag_idx]
conn = self.frag['frag_conn'][frag_idx]
# lookup fingerprint
fingerprint = self.fingerprints['fingerprint_data'][
self.fingerprints['fingerprint_mapping'][smiles.decode('ascii')]
]
# lookup receptor
rec_idx = self.rec['rec_mapping'][rec_id.decode('ascii')]
_, r_start, r_end = self.rec['rec_lookup'][rec_idx]
# fetch data
f_coords = self.frag['frag_coords'][f_start:f_end]
f_types = self.frag['frag_remapped'][f_start:f_end]
p_coords = self.frag['frag_coords'][p_start:p_end]
p_types = self.frag['frag_remapped'][p_start:p_end]
r_coords = self.rec['rec_coords'][r_start:r_end]
r_types = self.rec['rec_remapped'][r_start:r_end]
extra = {
'freq': self.freq[idx]
}
return f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fingerprint, extra
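# minimal usage sketch (paths and filter values are placeholders):
#   dataset = FragmentDataset('fragments.h5', 'fingerprints.h5', fdist_max=4)
#   f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fp, extra = dataset[0]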
'''
grid_util.py
contains code for gpu-accelerated grid generation
'''
import math
import ctypes
import torch
import numba
import numba.cuda
import numpy as np
GPU_DIM = 8
@numba.cuda.jit
def gpu_gridify(grid, width, res, center, rot, atom_num, atom_coords, atom_types, layer_offset, batch_i):
'''
GPU kernel to add atoms to a grid
width, res, offset and rot control the grid view
'''
x,y,z = numba.cuda.grid(3)
# guard: thread indices can exceed the grid when width is not a multiple of GPU_DIM
if x >= width or y >= width or z >= width:
    return
# center around origin
tx = x - (width/2)
ty = y - (width/2)
tz = z - (width/2)
# scale by resolution
tx = tx * res
ty = ty * res
tz = tz * res
# apply rotation quaternion q=(w,x,y,z): rotate p=(0,tx,ty,tz) as p' = q p q*
aw = rot[0]
ax = rot[1]
ay = rot[2]
az = rot[3]
bw = 0
bx = tx
by = ty
bz = tz
# first product: c = q * p
cw = (aw * bw) - (ax * bx) - (ay * by) - (az * bz)
cx = (aw * bx) + (ax * bw) + (ay * bz) - (az * by)
cy = (aw * by) + (ay * bw) + (az * bx) - (ax * bz)
cz = (aw * bz) + (az * bw) + (ax * by) - (ay * bx)
# second product with the conjugate: d = c * q* (the real part dw is
# always zero when rotating a pure-vector quaternion, so it is omitted)
# dw = (cw * aw) - (cx * (-ax)) - (cy * (-ay)) - (cz * (-az))
dx = (cw * (-ax)) + (cx * aw) + (cy * (-az)) - (cz * (-ay))
dy = (cw * (-ay)) + (cy * aw) + (cz * (-ax)) - (cx * (-az))
dz = (cw * (-az)) + (cz * aw) + (cx * (-ay)) - (cy * (-ax))
# apply translation vector
tx = dx + center[0]
ty = dy + center[1]
tz = dz + center[2]
i = 0
while i < atom_num:
# fetch atom
fx, fy, fz = atom_coords[i]
ft = atom_types[i][0]
i += 1
# skip invisible atoms (no layer mapping, type -1)
if ft == -1:
continue
r2 = 4  # squared cutoff radius (2 angstroms)
# conservative bounding-box rejection (box half-width r2 exceeds the cutoff radius)
if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2:
continue
# compute squared distance to atom
d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
# compute effect (gaussian falloff from the atom center)
v = math.exp((-2 * d2) / r2)
# add effect
if d2 < r2:
grid[batch_i,layer_offset+ft,x,y,z] += v
def make_tensor(shape):
'''
Create a pytorch tensor and numba array with the same GPU memory backing
'''
# get cuda context
ctx = numba.cuda.cudadrv.driver.driver.get_active_context()
# setup tensor on gpu
t = torch.zeros(size=shape, dtype=torch.float32).cuda()
memory = numba.cuda.cudadrv.driver.MemoryPointer(ctx, ctypes.c_ulong(t.data_ptr()), t.numel() * 4)
cuda_arr = numba.cuda.cudadrv.devicearray.DeviceNDArray(
t.size(),
[i*4 for i in t.stride()],
np.dtype('float32'),
gpu_data=memory,
stream=torch.cuda.current_stream().cuda_stream
)
return (t, cuda_arr)
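# note: the tensor and the numba array alias the same gpu buffer, so kernel
# writes through the numba view are visible to pytorch without a copy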
def rand_rot():
'''
Sample a uniformly random 3d rotation.
Returns a unit quaternion (w,x,y,z); a normalized 4d gaussian sample
is uniformly distributed on the unit 3-sphere.
'''
q = np.random.normal(size=4) # sample quaternion from normal distribution
q = q / np.sqrt(np.sum(q**2)) # normalize
return q
def mol_gridify(
grid,
mol_atoms,
mol_types,
batch_i,
center=np.array([0,0,0]),
width=48,
res=0.5,
rot=np.array([1,0,0,0]),
layer_offset=0,
):
'''wrapper to invoke gpu gridify kernel'''
dw = ((width - 1) // GPU_DIM) + 1
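# launch ceil(width / GPU_DIM) blocks per axis, each with GPU_DIM^3 threads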
gpu_gridify[(dw,dw,dw), (GPU_DIM,GPU_DIM,GPU_DIM)](
grid,
width,
res,
center,
rot,
len(mol_atoms),
mol_atoms,
mol_types,
layer_offset,
batch_i
)
def get_batch(data, batch_set=None, batch_size=16, width=48, res=0.5, ignore_receptor=False, ignore_parent=False, include_freq=False):
assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"
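# grid channels: 9 parent layers + 9 receptor layers (only 9 if one branch is ignored)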
dim = 18
if ignore_receptor or ignore_parent:
dim = 9
# create a tensor with shared memory on the gpu
t, grid = make_tensor((batch_size, dim, width, width, width))
if batch_set is None:
batch_set = np.random.choice(len(data), size=batch_size, replace=False)
fingerprints = np.zeros((batch_size, data.fingerprints['fingerprint_data'].shape[1]))
freq = np.zeros(batch_size)
for i in range(len(batch_set)):
idx = batch_set[i]
f_coords, f_types, p_coords, p_types, r_coords, r_types, conn, fp, extra = data[idx]
# random rotation
rot = rand_rot()
if ignore_receptor:
mol_gridify(grid, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
elif ignore_parent:
mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
else:
mol_gridify(grid, p_coords, p_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=0)
mol_gridify(grid, r_coords, r_types, batch_i=i, center=conn, width=width, res=res, rot=rot, layer_offset=9)
fingerprints[i] = fp
freq[i] = extra['freq']
t_fingerprints = torch.Tensor(fingerprints).cuda()
t_freq = torch.Tensor(freq).cuda()
if include_freq: