Commit 02de11c6 authored by hgarrereyn's avatar hgarrereyn
Browse files

simplify metric tracking, add inside_support metric

parent 2608cb67
......@@ -77,7 +77,7 @@ def average_support(fingerprints, fn):
dist -= p_dist[i]
dist *= -1
dist_n = F.sigmoid(dist)
dist_n = torch.sigmoid(dist)
c[i] = torch.mean(dist_n)
......@@ -88,6 +88,36 @@ def average_support(fingerprints, fn):
return _average_support
def inside_support(fingerprints, fn):
    """Build a metric scoring how many fingerprints sit *inside* the support.

    Like average_support, but only fingerprints that are closer to the
    prediction than the correct fingerprint contribute; strictly farther
    ones are clamped out.

    Args:
        fingerprints: tensor of all candidate fragment fingerprints.
        fn: pairwise distance function, applied via broadcast_fn.

    Returns:
        A callable _inside_support(yp, yt) -> scalar tensor score.
    """
    def _inside_support(yp, yt):
        # Distance from each prediction to its own correct fingerprint.
        p_dist = broadcast_fn(fn, yp, yt)

        c = torch.empty(yp.shape[0])
        for i in range(yp.shape[0]):
            # Distance from prediction i to every candidate fingerprint.
            dist = broadcast_fn(fn, yp[i].unsqueeze(0), fingerprints)

            # Shift and flip sign so that candidates closer than the
            # correct fingerprint (bad examples) become positive.
            dist -= p_dist[i]
            dist *= -1

            # Ignore candidates that are farther away than the target.
            # NOTE(review): clamped entries still contribute sigmoid(0)=0.5
            # to the mean below — confirm this is intended.
            dist[dist < 0] = 0

            dist_n = torch.sigmoid(dist)
            c[i] = torch.mean(dist_n)

        score = torch.mean(c)
        return score

    return _inside_support
def top_k_acc(fingerprints, fn, k, pre=''):
"""Top-k accuracy metric.
......
......@@ -12,7 +12,8 @@ from leadopt.models.voxel import VoxelFingerprintNet
from leadopt.data_util import FragmentDataset, FingerprintDataset, LIG_TYPER,\
REC_TYPER
from leadopt.grid_util import get_batch
from leadopt.metrics import mse, bce, tanimoto, cos, top_k_acc, average_support
from leadopt.metrics import mse, bce, tanimoto, cos, top_k_acc,\
average_support, inside_support
from config import partitions
......@@ -40,7 +41,10 @@ LOSS_TYPE = {
'direct': _direct_loss,
# minimize distance to target and maximize distance to all other
'support_v1': average_support
'support_v1': average_support,
# support, limited to closer points
'support_v2': average_support,
}
......@@ -69,10 +73,15 @@ class RunLog(object):
class MetricTracker(object):
def __init__(self, name):
def __init__(self, name, metric_fns):
    """Create a tracker for a named phase (e.g. 'train' or 'val').

    Args:
        name: label for this tracker's phase.
        metric_fns: dict mapping metric name -> callable(yp, yt),
            evaluated together by evaluate().
    """
    self._name = name
    self._metric_fns = metric_fns
    # accumulated metric values, keyed by metric name
    self._metrics = {}
def evaluate(self, yp, yt):
    """Run every registered metric on (yp, yt) and record each result.

    Args:
        yp: predicted fingerprints.
        yt: target (correct) fingerprints.
    """
    # Iterate name/function pairs directly instead of iterating keys and
    # re-indexing the dict on every pass.
    for name, metric_fn in self._metric_fns.items():
        self.update(name, metric_fn(yp, yt))
def update(self, name, metric):
if type(metric) is dict:
for subname in metric:
......@@ -305,12 +314,13 @@ class VoxelNet(LeadoptModel):
loss_fn = LOSS_TYPE[self._args['loss']](loss_fingerprints, dist_fn)
metrics = {
train_metrics = MetricTracker('train', {
'all': top_k_acc(all_fingerprints, dist_fn, [1,5,10,50,100], pre='all')
}
train_metrics = MetricTracker('train')
val_metrics = MetricTracker('val')
})
val_metrics = MetricTracker('val', {
'all': top_k_acc(all_fingerprints, dist_fn, [1,5,10,50,100], pre='all'),
'val': top_k_acc(val_fingerprints, dist_fn, [1,5,10,50,100], pre='val'),
})
best_loss = None
......@@ -346,9 +356,7 @@ class VoxelNet(LeadoptModel):
opt.step()
train_metrics.update('loss', loss)
for m in metrics:
train_metrics.update(m, metrics[m](predicted_fp, correct_fp))
train_metrics.evaluate(predicted_fp, correct_fp)
train_metrics.normalize(self._args['batch_size'])
self._log.log(train_metrics.get_all())
train_metrics.clear()
......@@ -382,8 +390,7 @@ class VoxelNet(LeadoptModel):
loss = loss_fn(predicted_fp, correct_fp)
val_metrics.update('loss', loss)
for m in metrics:
val_metrics.update(m, metrics[m](predicted_fp, correct_fp))
val_metrics.evaluate(predicted_fp, correct_fp)
val_metrics.normalize(self._args['test_steps'] * self._args['batch_size'])
self._log.log(val_metrics.get_all())
......
Supports Markdown
0% or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or sign in to comment