Commit 778c3ec1 authored by hgarrereyn's avatar hgarrereyn
Browse files

misc updates

parent d958ceca
......@@ -43,7 +43,7 @@ class AtomTyper(object):
class CondAtomTyper(AtomTyper):
def __init__(self, cond_func):
assert len(cond_func) <= 32
assert len(cond_func) <= 16
def _fn(*args):
v = 0
for k in range(len(cond_func)):
......@@ -183,12 +183,12 @@ class FragmentDataset(Dataset):
if self.verbose:
r = tqdm.tqdm(r, desc='Remap receptor atoms')
rec_remapped = np.zeros(len(rec_types), dtype=np.uint32)
rec_remapped = np.zeros(len(rec_types), dtype=np.uint16)
if not self._lazy_loading:
for i in r:
rec_remapped[i] = rec_typer.apply(*rec_types[i])
rec_loaded = np.zeros(len(rec_lookup)).astype(np.int32)
rec_loaded = np.zeros(len(rec_lookup)).astype(np.bool)
# create rec mapping
rec_mapping = {}
......@@ -226,14 +226,14 @@ class FragmentDataset(Dataset):
# unpack frag data into separate structures
frag_coords = frag_data[:,:3].astype(np.float32)
frag_types = frag_data[:,3].astype(np.int32)
frag_types = frag_data[:,3].astype(np.uint8)
frag_remapped = np.zeros(len(frag_types), dtype=np.uint32)
frag_remapped = np.zeros(len(frag_types), dtype=np.uint16)
if not self._lazy_loading:
for i in range(len(frag_types)):
frag_remapped[i] = lig_typer.apply(frag_types[i])
frag_loaded = np.zeros(len(frag_lookup)).astype(np.int32)
frag_loaded = np.zeros(len(frag_lookup)).astype(np.bool)
# find and save connection point
r = range(len(frag_lookup))
......@@ -353,7 +353,9 @@ class FragmentDataset(Dataset):
"""
# convert to fragment index
frag_idx = self.valid_idx[idx]
return self.get_raw(frag_idx)
def get_raw(self, frag_idx):
# lookup fragment
rec_id, f_start, f_end, p_start, p_end = self.frag['frag_lookup'][frag_idx]
smiles = self.frag['frag_smiles'][frag_idx].decode('ascii')
......@@ -420,6 +422,39 @@ class FragmentDataset(Dataset):
return self._rec_typer.size()
class SharedFragmentDataset(object):
    """Filtered view over a FragmentDataset that shares its underlying arrays.

    Several instances (e.g. a train and a val split) can wrap the *same*
    FragmentDataset, each holding only its own filtered index set instead of
    a full copy of the (large) fragment data.
    """

    def __init__(self, dat, filter_rec=None, filter_smi=None, fdist_min=None,
                 fdist_max=None, fmass_min=None, fmass_max=None):
        # Keep a reference to the shared dataset and precompute which
        # fragment indices survive the requested filters.
        self._dat = dat
        self.valid_idx = self._dat._get_valid_examples(
            filter_rec, filter_smi, fdist_min, fdist_max, fmass_min,
            fmass_max, verbose=True)

    def __len__(self):
        return len(self.valid_idx)

    def __getitem__(self, idx):
        # Translate the filtered index into an underlying fragment index.
        return self._dat.get_raw(self.valid_idx[idx])

    def get_valid_smiles(self):
        """Return a list of the unique SMILES strings among valid fragments."""
        seen = set()
        for frag_idx in self.valid_idx:
            seen.add(self._dat.frag['frag_smiles'][frag_idx].decode('ascii'))
        return list(seen)

    def lig_layers(self):
        # Delegate to the shared dataset (layer counts are split-independent).
        return self._dat.lig_layers()

    def rec_layers(self):
        return self._dat.rec_layers()
class FingerprintDataset(Dataset):
def __init__(self, fingerprint_file):
......
......@@ -265,6 +265,9 @@ def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5,
"""
assert (not (ignore_receptor and ignore_parent)), "Can't ignore parent and receptor!"
batch_size = int(batch_size)
width = int(width)
rec_channels = data.rec_layers()
lig_channels = data.lig_layers()
......
......@@ -149,7 +149,7 @@ def top_k_acc(fingerprints, fn, k, pre=''):
for j in range(len(k)):
c[i,j] = int(count < k[j])
score = torch.mean(c, axis=0)
score = torch.mean(c, 0)
m = {'%sacc_%d' % (pre, h): v.item() for h,v in zip(k,score)}
return m
......
......@@ -10,7 +10,7 @@ import tqdm
import numpy as np
from leadopt.models.voxel import VoxelFingerprintNet
from leadopt.data_util import FragmentDataset, FingerprintDataset, LIG_TYPER,\
from leadopt.data_util import FragmentDataset, SharedFragmentDataset, FingerprintDataset, LIG_TYPER,\
REC_TYPER
from leadopt.grid_util import get_batch
from leadopt.metrics import mse, bce, tanimoto, cos, top_k_acc,\
......@@ -293,39 +293,74 @@ class VoxelNet(LeadoptModel):
).to(self._device)
return {'voxel': voxel}
def train(self, save_path=None):
def load_data(self):
print('[*] Loading data...', flush=True)
train_dat = FragmentDataset(
# train_dat = FragmentDataset(
# self._args['fragments'],
# rec_typer=REC_TYPER[self._args['rec_typer']],
# lig_typer=LIG_TYPER[self._args['lig_typer']],
# # filter_rec=(
# # partitions.TRAIN if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.TRAIN)),
# filter_smi=set(moad_partitions.TRAIN_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )
# val_dat = FragmentDataset(
# self._args['fragments'],
# rec_typer=REC_TYPER[self._args['rec_typer']],
# lig_typer=LIG_TYPER[self._args['lig_typer']],
# # filter_rec=(
# # partitions.VAL if not self._args['no_partitions'] else None),
# filter_rec=set(get_bios(moad_partitions.VAL)),
# filter_smi=set(moad_partitions.VAL_SMI),
# fdist_min=self._args['fdist_min'],
# fdist_max=self._args['fdist_max'],
# fmass_min=self._args['fmass_min'],
# fmass_max=self._args['fmass_max'],
# verbose=True
# )
dat = FragmentDataset(
self._args['fragments'],
rec_typer=REC_TYPER[self._args['rec_typer']],
lig_typer=LIG_TYPER[self._args['lig_typer']],
# filter_rec=(
# partitions.TRAIN if not self._args['no_partitions'] else None),
verbose=True
)
train_dat = SharedFragmentDataset(
dat,
filter_rec=set(get_bios(moad_partitions.TRAIN)),
filter_smi=set(moad_partitions.TRAIN_SMI),
fdist_min=self._args['fdist_min'],
fdist_max=self._args['fdist_max'],
fmass_min=self._args['fmass_min'],
fmass_max=self._args['fmass_max'],
verbose=True
)
val_dat = FragmentDataset(
self._args['fragments'],
rec_typer=REC_TYPER[self._args['rec_typer']],
lig_typer=LIG_TYPER[self._args['lig_typer']],
# filter_rec=(
# partitions.VAL if not self._args['no_partitions'] else None),
val_dat = SharedFragmentDataset(
dat,
filter_rec=set(get_bios(moad_partitions.VAL)),
filter_smi=set(moad_partitions.VAL_SMI),
fdist_min=self._args['fdist_min'],
fdist_max=self._args['fdist_max'],
fmass_min=self._args['fmass_min'],
fmass_max=self._args['fmass_max'],
verbose=True
)
return train_dat, val_dat
def train(self, save_path=None, custom_steps=None, checkpoint_callback=None, data=None):
if data is None:
data = self.load_data()
train_dat, val_dat = data
fingerprints = FingerprintDataset(self._args['fingerprints'])
train_smiles = train_dat.get_valid_smiles()
......@@ -345,10 +380,17 @@ class VoxelNet(LeadoptModel):
print('[*] Val smiles: %d' % val_fingerprints.shape[0])
print('[*] All smiles: %d' % all_fingerprints.shape[0])
# memory optimization, drop some unnecessary columns
train_dat._dat.frag['frag_mass'] = None
train_dat._dat.frag['frag_dist'] = None
train_dat._dat.frag['frag_lig_smi'] = None
train_dat._dat.frag['frag_lig_idx'] = None
print('[*] Training...', flush=True)
opt = torch.optim.Adam(
self._models['voxel'].parameters(), lr=self._args['learning_rate'])
steps_per_epoch = len(train_dat) // self._args['batch_size']
steps_per_epoch = custom_steps if custom_steps is not None else steps_per_epoch
# configure metrics
dist_fn = DIST_FN[self._args['dist_fn']]
......@@ -360,11 +402,11 @@ class VoxelNet(LeadoptModel):
loss_fn = LOSS_TYPE[self._args['loss']](loss_fingerprints, dist_fn)
train_metrics = MetricTracker('train', {
'all': top_k_acc(all_fingerprints, dist_fn, [1,5,10,50,100], pre='all')
'all': top_k_acc(all_fingerprints, dist_fn, [1,8,64], pre='all')
})
val_metrics = MetricTracker('val', {
'all': top_k_acc(all_fingerprints, dist_fn, [1,5,10,50,100], pre='all'),
'val': top_k_acc(val_fingerprints, dist_fn, [1,5,10,50,100], pre='val'),
'all': top_k_acc(all_fingerprints, dist_fn, [1,8,64], pre='all'),
# 'val': top_k_acc(val_fingerprints, dist_fn, [1,5,10,50,100], pre='val'),
})
best_loss = None
......@@ -438,6 +480,9 @@ class VoxelNet(LeadoptModel):
val_metrics.update('loss', loss)
val_metrics.evaluate(predicted_fp, correct_fp)
if checkpoint_callback:
checkpoint_callback(self, epoch)
val_metrics.normalize(self._args['test_steps'])
self._log.log(val_metrics.get_all())
......@@ -451,15 +496,15 @@ class VoxelNet(LeadoptModel):
val_metrics.clear()
def run_test(self, save_path):
def run_test(self, save_path, use_val=False):
# load test dataset
test_dat = FragmentDataset(
self._args['fragments'],
rec_typer=REC_TYPER[self._args['rec_typer']],
lig_typer=LIG_TYPER[self._args['lig_typer']],
# filter_rec=partitions.TEST,
filter_rec=set(get_bios(moad_partitions.TEST)),
filter_smi=set(moad_partitions.TEST_SMI),
filter_rec=set(get_bios(moad_partitions.VAL if use_val else moad_partitions.TEST)),
filter_smi=set(moad_partitions.VAL_SMI if use_val else moad_partitions.TEST_SMI),
fdist_min=self._args['fdist_min'],
fdist_max=self._args['fdist_max'],
fmass_min=self._args['fmass_min'],
......@@ -509,8 +554,12 @@ class VoxelNet(LeadoptModel):
example_idx, sample_idx = batch[j]
predicted_fp[example_idx][sample_idx] = predicted[j].detach().cpu().numpy()
np.save(os.path.join(save_path, 'predicted_fp.npy'), predicted_fp)
np.save(os.path.join(save_path, 'correct_fp.npy'), correct_fp)
if use_val:
np.save(os.path.join(save_path, 'val_predicted_fp.npy'), predicted_fp)
np.save(os.path.join(save_path, 'val_correct_fp.npy'), correct_fp)
else:
np.save(os.path.join(save_path, 'predicted_fp.npy'), predicted_fp)
np.save(os.path.join(save_path, 'correct_fp.npy'), correct_fp)
print('done.')
......
......@@ -3,11 +3,20 @@ import torch
import torch.nn as nn
import torch.nn.functional as F
# nn.Flatten was added in newer PyTorch releases; on older versions fall
# back to the local backport implementation so the model definitions below
# work either way.
Flatten = None
try:
    Flatten = nn.Flatten
# Only AttributeError is possible here (the attribute lookup on `nn`);
# a bare `except:` would also mask unrelated errors such as
# KeyboardInterrupt, so catch the specific exception instead.
except AttributeError:
    from . import backport
    Flatten = backport.Flatten
class VoxelFingerprintNet(nn.Module):
def __init__(self, in_channels, output_size, blocks=[32,64], fc=[2048], pad=True):
super(VoxelFingerprintNet, self).__init__()
blocks = list(blocks)
fc = list(fc)
self.blocks = nn.ModuleList()
prev = in_channels
for i in range(len(blocks)):
......@@ -32,7 +41,7 @@ class VoxelFingerprintNet(nn.Module):
self.reduce = nn.Sequential(
nn.AdaptiveAvgPool3d((1,1,1)),
nn.Flatten(),
Flatten(),
)
pred = []
......
......@@ -29,6 +29,13 @@ def rdkfingerprint(m):
return n_fp
def rdkfingerprint10(m):
    '''Return the RDKit topological fingerprint of m (maxPath=10) as a
    2048-length list of 0/1 ints.'''
    bits = Chem.rdmolops.RDKFingerprint(m, maxPath=10).ToBitString()
    return [int(b) for b in bits]
def morganfingerprint(m):
'''morgan fingerprint as 2048-len bit array'''
m.UpdatePropertyCache(strict=False)
......@@ -49,6 +56,7 @@ def gobbi2d(m):
# Registry of available fingerprint functions, keyed by name.
# Each value is (fn, size): fn maps a molecule to a fixed-length 0/1 list
# and size is that list's length (2048 for all current entries).
FINGERPRINTS = {
    'rdk': (rdkfingerprint, 2048),
    'rdk10': (rdkfingerprint10, 2048),
    'morgan': (morganfingerprint, 2048),
    'gobbi2d': (gobbi2d, 2048),
}
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment