Commit 1c333a36 authored by Harrison Green

add topk acc metrics, refactor model, support expanded label training

parent 5b57c647
@@ -122,6 +122,7 @@ def average_support(fingerprints, fn):
     return g


 def average_support_mse(fingerprints):
     def fn(yp, yt):
         return torch.sum((yp - yt) ** 2, axis=1)
@@ -160,6 +161,7 @@ def average_support_weighted(fingerprints, fn):
     return g


 def average_support_weighted_mse(fingerprints):
     def fn(yp, yt, att):
@@ -210,127 +212,61 @@ def average_position_weighted_mse(fingerprints, norm):
     return average_position_weighted(fingerprints, fn, norm)


-class Model(object):
-    def __init__(self):
-        pass
-
-    def setup_parser(self, name, parser):
-        pass
-
-    def build_model(self, args):
-        pass
-
-    def get_metrics(self, args, std, mean, train_dat, test_dat):
-        '''Returns a dict of metrics'''
-        pass
-
-    def get_train_mode(self):
-        pass
-
-
-class Voxel1(Model):
-    def setup_parser(self, name, parser):
-        sub = parser.add_parser(name)
-        sub.add_argument('--in_channels', type=int, default=18)
-        sub.add_argument('--output_size', type=int, default=2048)
-        sub.add_argument('--sigmoid', default=False, action='store_true')
-        sub.add_argument('--f1', type=int, default=32)
-        sub.add_argument('--f2', type=int, default=64)
-        sub.add_argument('--f3', type=int, default=128)
-
-    def build_model(self, args):
-        m = VoxelFingerprintNet(
-            in_channels=args.in_channels,
-            output_size=args.output_size,
-            sigmoid=args.sigmoid,
-            f1=args.f1,
-            f2=args.f2,
-            f3=args.f3
-        ).cuda()
-        return m
-
-    def get_metrics(self, args, std, mean, train_dat, test_dat):
-        '''Returns a dict of metrics'''
-        return {
-            'tanimoto': tanimoto
-        }
-
-    def get_train_mode(self):
-        return train
-
-
-class Voxel2(Model):
-    def setup_parser(self, name, parser):
-        sub = parser.add_parser(name)
-        sub.add_argument('--in_channels', type=int, default=18)
-        sub.add_argument('--output_size', type=int, default=2048)
-        sub.add_argument('--sigmoid', default=False, action='store_true')
-        sub.add_argument('--batchnorm', default=False, action='store_true')
-        sub.add_argument('--f1', type=int, default=32)
-        sub.add_argument('--f2', type=int, default=64)
-
-    def build_model(self, args):
-        m = VoxelFingerprintNet2(
-            in_channels=args.in_channels,
-            output_size=args.output_size,
-            sigmoid=args.sigmoid,
-            batchnorm=args.batchnorm,
-            f1=args.f1,
-            f2=args.f2
-        ).cuda()
-        return m
-
-    def get_metrics(self, args, std, mean, train_dat, test_dat):
-        all_fp = list(set(train_dat.valid_fingerprints + test_dat.valid_fingerprints))
-        fingerprints = train_dat.fingerprints['fingerprint_data'][all_fp]
-        train_fingerprints = train_dat.fingerprints['fingerprint_data'][train_dat.valid_fingerprints]
-        test_fingerprints = train_dat.fingerprints['fingerprint_data'][test_dat.valid_fingerprints]
-
-        metrics = {}
-
-        if args.fingerprints == 'fingerprints.h5':
-            metrics.update({
-                'within_1_mass': denorm(within_k_mass(1), std, mean),
-                'within_10_mass': denorm(within_k_mass(10), std, mean),
-                'within_50_mass': denorm(within_k_mass(50), std, mean),
-                'within_100_mass': denorm(within_k_mass(100), std, mean),
-            })
-
-        POS_METRICS = [
-            ('mse', average_position_mse),
-            ('cos', average_position_cos),
-        ]
-
-        if args.sigmoid:
-            POS_METRICS += [
-                ('bce', average_position_bce),
-                ('tanimoto', average_position_tanimoto),
-            ]
-
-        for name, eval_fn in POS_METRICS:
-            metrics.update({
-                'pos_%s' % name: eval_fn(fingerprints, norm=True),
-                'pos_%s_raw' % name: eval_fn(fingerprints, norm=False),
-            })
-
-        return metrics
-
-    def get_loss(self, args, std, mean, train_dat, test_dat):
-        all_fp = list(set(train_dat.valid_fingerprints + test_dat.valid_fingerprints))
-        fingerprints = train_dat.fingerprints['fingerprint_data'][all_fp]
-
-        loss_fn = average_support_mse(fingerprints)
-        return loss_fn
-
-    def get_train_mode(self):
-        return train
-
-
-class Voxel2b(Model):
+def top_k_acc(t_fingerprints, fn, k, pre=''):
+    def fn_br(yp,yt):
+        yp_b, yt_b = torch.broadcast_tensors(yp, yt)
+        return fn(yp_b.detach(), yt_b.detach())
+
+    def g(yp, yt):
+        # correct distance
+        p_dist = fn_br(yp, yt)
+
+        c = torch.empty(yp.shape[0], len(k))
+        for i in range(yp.shape[0]):
+            # compute distance to all other fragments
+            dist = fn_br(yp[i].unsqueeze(0), t_fingerprints)
+
+            # number of fragments that are closer or equal
+            count = torch.sum((dist < p_dist[i]).to(torch.float))
+
+            for j in range(len(k)):
+                c[i,j] = int(count < k[j])
+
+        score = torch.mean(c, axis=0)
+
+        m = {'%sacc_%d' % (pre,h): v.item() for h,v in zip(k,score)}
+        return m
+
+    return g
+
+
+def top_k_acc_mse(t_fingerprints, k, pre):
+    def fn(yp, yt):
+        return torch.sum((yp - yt) ** 2, axis=1)
+    return top_k_acc(t_fingerprints, fn, k, pre)
+
+
+class Model(object):
+    def __init__(self):
+        pass
+
+    def setup_parser(self, name, parser):
+        pass
+
+    def build_model(self, args):
+        pass
+
+    def get_metrics(self, args, std, mean, train_dat, test_dat):
+        '''Returns a dict of metrics'''
+        pass
+
+    def get_train_mode(self):
+        pass
+
+
+class VoxelNet(Model):
     def setup_parser(self, name, parser):
         sub = parser.add_parser(name)
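A rough usage sketch of the new metric factory (not part of the commit; the candidate-set size, fingerprint dimension, and CPU tensors are illustrative). top_k_acc_mse(t_fp, k, pre) returns a callable that, for each predicted fingerprint, counts how many candidate fingerprints are closer (by squared error) than the paired true fingerprint, and scores the prediction correct at cutoff k if fewer than k candidates are closer:

    import torch

    # hypothetical candidate fingerprint set (N x D), e.g. all valid fragment fingerprints
    t_fp = torch.rand(100, 2048)

    acc_fn = top_k_acc_mse(t_fp, k=[1, 5, 10], pre='all')   # factory defined in the diff above

    yp = torch.rand(8, 2048)   # predicted fingerprints for a batch
    yt = t_fp[:8]              # paired "true" fingerprints

    # dict keyed as '%sacc_%d', e.g. {'allacc_1': ..., 'allacc_5': ..., 'allacc_10': ...}
    print(acc_fn(yp, yt))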
@@ -340,6 +276,8 @@ class Voxel2b(Model):
         sub.add_argument('--blocks', nargs='+', type=int, default=[32,64])
         sub.add_argument('--fc', nargs='+', type=int, default=[2048])
+        sub.add_argument('--use_all_labels', default=False, action='store_true')

     def build_model(self, args):
         m = VoxelFingerprintNet2b(
             in_channels=args.in_channels,
@@ -358,7 +296,13 @@ class Voxel2b(Model):
         train_fingerprints = train_dat.fingerprints['fingerprint_data'][train_dat.valid_fingerprints]
         test_fingerprints = train_dat.fingerprints['fingerprint_data'][test_dat.valid_fingerprints]

-        metrics = {}
+        t_all_fp = torch.Tensor(fingerprints).cuda()
+        t_test_fp = torch.Tensor(test_fingerprints).cuda()
+
+        metrics = {
+            'acc': top_k_acc_mse(t_all_fp, k=[1,5,10,50,100], pre='all'),
+            'acc2': top_k_acc_mse(t_test_fp, k=[1,5,10,50,100], pre='test'),
+        }

         POS_METRICS = [
             ('mse', average_position_mse),
@@ -374,7 +318,11 @@ class Voxel2b(Model):
     def get_loss(self, args, std, mean, train_dat, test_dat):
-        all_fp = list(set(train_dat.valid_fingerprints))
+        all_fp = set(train_dat.valid_fingerprints)
+        if args.use_all_labels:
+            all_fp |= set(test_dat.valid_fingerprints)
+        all_fp = list(all_fp)
+
         fingerprints = train_dat.fingerprints['fingerprint_data'][all_fp]

         loss_fn = average_support_mse(fingerprints)
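The new --use_all_labels path ("support expanded label training" in the commit message) widens the loss support set from the training-split fingerprints to the union of the train and test splits before it is handed to average_support_mse. A minimal illustration with hypothetical index lists:

    train_ids = [0, 1, 2, 5]          # stand-in for train_dat.valid_fingerprints
    test_ids = [2, 3, 4]              # stand-in for test_dat.valid_fingerprints
    use_all_labels = True             # corresponds to the --use_all_labels flag

    all_fp = set(train_ids)
    if use_all_labels:
        all_fp |= set(test_ids)       # union in the test-split labels
    all_fp = list(all_fp)

    print(sorted(all_fp))             # [0, 1, 2, 3, 4, 5]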
......@@ -540,9 +488,7 @@ class Skip1(Model):
MODELS = {
'voxel1': Voxel1(),
'voxel2': Voxel2(),
'voxel2b': Voxel2b(),
'voxelnet': VoxelNet(),
'voxel3': Voxel3(),
'voxel4': Voxel4(),
'latent1': Latent1(),
......
@@ -87,13 +87,13 @@ def train(model, run_path, train_dat, test_dat, metrics, loss_fn, args):
             calc_metrics = {'loss': loss}
             for m in metrics:
-                calc_metrics[m] = metrics[m](y, fp)
-
-            # for m in metrics:
-            #     wm = 'weighted_%s' % m
-            #     calc_metrics[wm] = 0
-            #     for i in range(args.batch_size):
-            #         calc_metrics[wm] += metrics[m](y[i].unsqueeze(0), fp[i].unsqueeze(0)) * w[i]
-            #     calc_metrics[wm] /= args.batch_size
+                r = metrics[m](y, fp)
+
+                # dict: multiple values
+                if type(r) is dict:
+                    calc_metrics.update(r)
+                else:
+                    calc_metrics[m] = r

             wandb.log(calc_metrics)
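With this change a metric callable may return either a scalar or a dict of named values (as top_k_acc does), and dict results are merged into the logged metrics. A small standalone sketch of the dispatch, with hypothetical metric functions and no wandb dependency:

    import torch

    def mse_metric(y, fp):
        return torch.mean((y - fp) ** 2)                 # scalar-valued metric

    def topk_like_metric(y, fp):
        return {'allacc_1': 0.25, 'allacc_5': 0.60}      # dict-valued metric

    metrics = {'mse': mse_metric, 'acc': topk_like_metric}

    y, fp = torch.rand(4, 8), torch.rand(4, 8)

    calc_metrics = {'loss': 0.0}
    for m in metrics:
        r = metrics[m](y, fp)
        if type(r) is dict:
            calc_metrics.update(r)      # flatten named values into the log dict
        else:
            calc_metrics[m] = r

    print(calc_metrics)                 # {'loss': 0.0, 'mse': tensor(...), 'allacc_1': 0.25, 'allacc_5': 0.60}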
@@ -144,7 +144,16 @@ def train(model, run_path, train_dat, test_dat, metrics, loss_fn, args):
                 val_metrics['val_loss'] += loss
                 for m in metrics:
-                    val_metrics['val_%s' % m] += metrics[m](y, fp)
+                    r = metrics[m](y, fp)
+
+                    # dict: multiple values
+                    if type(r) is dict:
+                        val_metrics.update({
+                            'val_%s' % k: r[k] for k in r
+                        })
+                    else:
+                        val_metrics['val_%s' % m] += r

                 # for m in metrics:
                 #     wm = 'val_weighted_%s' % m
                 #     t = 0
......
@@ -13,14 +13,9 @@ from leadopt.data_util import FragmentDataset
 from config import partitions
-from models import MODELS
+from leadopt.model_conf import MODELS

-# LOSS = {
-#     'bce': nn.BCELoss(),
-#     'mse': nn.MSELoss(),
-# }

 def main():
     parser = argparse.ArgumentParser()
@@ -125,7 +120,6 @@ def main():
     metrics = MODELS[args.version].get_metrics(args, std, mean, train_dat, test_dat)

     # create loss function
-    # loss_fn = LOSS[args.loss]
     loss_fn = MODELS[args.version].get_loss(args, std, mean, train_dat, test_dat)

     train_func = MODELS[args.version].get_train_mode()
......