Commit 4d81a065 authored by jdurrant's avatar jdurrant
Browse files

Added license text to files.

parent d6e0f779
Pipeline #323 failed with stages
in 0 seconds
...@@ -172,27 +172,3 @@ any liability incurred by, or claims asserted against, such Contributor by ...@@ -172,27 +172,3 @@ any liability incurred by, or claims asserted against, such Contributor by
reason of your accepting any such warranty or additional liability. reason of your accepting any such warranty or additional liability.
_END OF TERMS AND CONDITIONS_ _END OF TERMS AND CONDITIONS_
### APPENDIX: How to apply the Apache License to your work
To apply the Apache License to your work, attach the following boilerplate
notice, with the fields enclosed by brackets `[]` replaced with your own
identifying information. (Don't include the brackets!) The text should be
enclosed in the appropriate comment syntax for the file format. We also
recommend that a file or class name and description of purpose be included on
the same "printed page" as the copyright notice for easier identification
within third-party archives.
Copyright [yyyy] [name of copyright owner]
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
This diff is collapsed.
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
""" """
Contains utility code for reading packed data files. Contains utility code for reading packed data files.
""" """
...@@ -10,15 +25,15 @@ import h5py ...@@ -10,15 +25,15 @@ import h5py
import tqdm import tqdm
# Atom typing # Atom typing
# #
# Atom typing is the process of figuring out which layer each atom should be # Atom typing is the process of figuring out which layer each atom should be
# written to. For ease of testing, the packed data file contains a lot of # written to. For ease of testing, the packed data file contains a lot of
# potentially useful atomic information which can be distilled during the # potentially useful atomic information which can be distilled during the
# data loading process. # data loading process.
# #
# Atom typing is implemented by map functions of the type: # Atom typing is implemented by map functions of the type:
# (atom descriptor) -> (layer index) # (atom descriptor) -> (layer index)
# #
# If the layer index is -1, the atom is ignored. # If the layer index is -1, the atom is ignored.
...@@ -139,13 +154,13 @@ LIG_TYPER = { ...@@ -139,13 +154,13 @@ LIG_TYPER = {
class FragmentDataset(Dataset): class FragmentDataset(Dataset):
"""Utility class to work with the packed fragments.h5 format.""" """Utility class to work with the packed fragments.h5 format."""
def __init__(self, fragment_file, rec_typer=REC_TYPER['simple'], def __init__(self, fragment_file, rec_typer=REC_TYPER['simple'],
lig_typer=LIG_TYPER['simple'], filter_rec=None, filter_smi=None, lig_typer=LIG_TYPER['simple'], filter_rec=None, filter_smi=None,
fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None, fdist_min=None, fdist_max=None, fmass_min=None, fmass_max=None,
verbose=False, lazy_loading=True): verbose=False, lazy_loading=True):
"""Initializes the fragment dataset. """Initializes the fragment dataset.
Args: Args:
fragment_file: path to fragments.h5 fragment_file: path to fragments.h5
rec_typer: AtomTyper for receptor rec_typer: AtomTyper for receptor
...@@ -174,11 +189,11 @@ class FragmentDataset(Dataset): ...@@ -174,11 +189,11 @@ class FragmentDataset(Dataset):
def _load_rec(self, fragment_file, rec_typer): def _load_rec(self, fragment_file, rec_typer):
"""Loads receptor information.""" """Loads receptor information."""
f = h5py.File(fragment_file, 'r') f = h5py.File(fragment_file, 'r')
rec_coords = f['rec_coords'][()] rec_coords = f['rec_coords'][()]
rec_types = f['rec_types'][()] rec_types = f['rec_types'][()]
rec_lookup = f['rec_lookup'][()] rec_lookup = f['rec_lookup'][()]
r = range(len(rec_types)) r = range(len(rec_types))
if self.verbose: if self.verbose:
r = tqdm.tqdm(r, desc='Remap receptor atoms') r = tqdm.tqdm(r, desc='Remap receptor atoms')
...@@ -194,7 +209,7 @@ class FragmentDataset(Dataset): ...@@ -194,7 +209,7 @@ class FragmentDataset(Dataset):
rec_mapping = {} rec_mapping = {}
for i in range(len(rec_lookup)): for i in range(len(rec_lookup)):
rec_mapping[rec_lookup[i][0].decode('ascii')] = i rec_mapping[rec_lookup[i][0].decode('ascii')] = i
rec = { rec = {
'rec_coords': rec_coords, 'rec_coords': rec_coords,
'rec_types': rec_types, 'rec_types': rec_types,
...@@ -203,11 +218,11 @@ class FragmentDataset(Dataset): ...@@ -203,11 +218,11 @@ class FragmentDataset(Dataset):
'rec_mapping': rec_mapping, 'rec_mapping': rec_mapping,
'rec_loaded': rec_loaded 'rec_loaded': rec_loaded
} }
f.close() f.close()
return rec return rec
def _load_fragments(self, fragment_file, lig_typer): def _load_fragments(self, fragment_file, lig_typer):
"""Loads fragment information.""" """Loads fragment information."""
f = h5py.File(fragment_file, 'r') f = h5py.File(fragment_file, 'r')
...@@ -227,12 +242,12 @@ class FragmentDataset(Dataset): ...@@ -227,12 +242,12 @@ class FragmentDataset(Dataset):
# unpack frag data into separate structures # unpack frag data into separate structures
frag_coords = frag_data[:,:3].astype(np.float32) frag_coords = frag_data[:,:3].astype(np.float32)
frag_types = frag_data[:,3].astype(np.uint8) frag_types = frag_data[:,3].astype(np.uint8)
frag_remapped = np.zeros(len(frag_types), dtype=np.uint16) frag_remapped = np.zeros(len(frag_types), dtype=np.uint16)
if not self._lazy_loading: if not self._lazy_loading:
for i in range(len(frag_types)): for i in range(len(frag_types)):
frag_remapped[i] = lig_typer.apply(frag_types[i]) frag_remapped[i] = lig_typer.apply(frag_types[i])
frag_loaded = np.zeros(len(frag_lookup)).astype(np.bool) frag_loaded = np.zeros(len(frag_lookup)).astype(np.bool)
# find and save connection point # find and save connection point
...@@ -244,7 +259,7 @@ class FragmentDataset(Dataset): ...@@ -244,7 +259,7 @@ class FragmentDataset(Dataset):
for i in r: for i in r:
_,f_start,f_end,_,_ = frag_lookup[i] _,f_start,f_end,_,_ = frag_lookup[i]
fdat = frag_data[f_start:f_end] fdat = frag_data[f_start:f_end]
found = False found = False
for j in range(len(fdat)): for j in range(len(fdat)):
if fdat[j][3] == 0: if fdat[j][3] == 0:
...@@ -253,7 +268,7 @@ class FragmentDataset(Dataset): ...@@ -253,7 +268,7 @@ class FragmentDataset(Dataset):
break break
assert found, "missing fragment connection point at %d" % i assert found, "missing fragment connection point at %d" % i
frag = { frag = {
'frag_coords': frag_coords, # d_idx -> (x,y,z) 'frag_coords': frag_coords, # d_idx -> (x,y,z)
'frag_types': frag_types, # d_idx -> (type) 'frag_types': frag_types, # d_idx -> (type)
...@@ -267,7 +282,7 @@ class FragmentDataset(Dataset): ...@@ -267,7 +282,7 @@ class FragmentDataset(Dataset):
'frag_lig_idx': frag_lig_idx, 'frag_lig_idx': frag_lig_idx,
'frag_loaded': frag_loaded 'frag_loaded': frag_loaded
} }
f.close() f.close()
return frag return frag
...@@ -275,14 +290,14 @@ class FragmentDataset(Dataset): ...@@ -275,14 +290,14 @@ class FragmentDataset(Dataset):
def _get_valid_examples(self, filter_rec, filter_smi, fdist_min, fdist_max, fmass_min, def _get_valid_examples(self, filter_rec, filter_smi, fdist_min, fdist_max, fmass_min,
fmass_max, verbose): fmass_max, verbose):
"""Returns an array of valid fragment indexes. """Returns an array of valid fragment indexes.
"Valid" in this context means the fragment belongs to a receptor in "Valid" in this context means the fragment belongs to a receptor in
filter_rec and the fragment abides by the optional mass/distance filter_rec and the fragment abides by the optional mass/distance
constraints. constraints.
""" """
# keep track of valid examples # keep track of valid examples
valid_mask = np.ones(self.frag['frag_lookup'].shape[0]).astype(np.bool) valid_mask = np.ones(self.frag['frag_lookup'].shape[0]).astype(np.bool)
num_frags = self.frag['frag_lookup'].shape[0] num_frags = self.frag['frag_lookup'].shape[0]
# filter by receptor id # filter by receptor id
...@@ -298,7 +313,7 @@ class FragmentDataset(Dataset): ...@@ -298,7 +313,7 @@ class FragmentDataset(Dataset):
if rec in filter_rec: if rec in filter_rec:
valid_rec[i] = 1 valid_rec[i] = 1
valid_mask *= valid_rec valid_mask *= valid_rec
# filter by ligand smiles string # filter by ligand smiles string
if filter_smi is not None: if filter_smi is not None:
valid_lig = np.zeros(num_frags, dtype=np.bool) valid_lig = np.zeros(num_frags, dtype=np.bool)
...@@ -318,10 +333,10 @@ class FragmentDataset(Dataset): ...@@ -318,10 +333,10 @@ class FragmentDataset(Dataset):
# filter by fragment distance # filter by fragment distance
if fdist_min is not None: if fdist_min is not None:
valid_mask[self.frag['frag_dist'] < fdist_min] = 0 valid_mask[self.frag['frag_dist'] < fdist_min] = 0
if fdist_max is not None: if fdist_max is not None:
valid_mask[self.frag['frag_dist'] > fdist_max] = 0 valid_mask[self.frag['frag_dist'] > fdist_max] = 0
# filter by fragment mass # filter by fragment mass
if fmass_min is not None: if fmass_min is not None:
valid_mask[self.frag['frag_mass'] < fmass_min] = 0 valid_mask[self.frag['frag_mass'] < fmass_min] = 0
...@@ -337,10 +352,10 @@ class FragmentDataset(Dataset): ...@@ -337,10 +352,10 @@ class FragmentDataset(Dataset):
def __len__(self): def __len__(self):
"""Returns the number of valid fragment examples.""" """Returns the number of valid fragment examples."""
return self.valid_idx.shape[0] return self.valid_idx.shape[0]
def __getitem__(self, idx): def __getitem__(self, idx):
"""Returns the Nth example. """Returns the Nth example.
Returns a dict with: Returns a dict with:
f_coords: fragment coordinates (Fx3) f_coords: fragment coordinates (Fx3)
f_types: fragment layers (Fx1) f_types: fragment layers (Fx1)
...@@ -354,23 +369,23 @@ class FragmentDataset(Dataset): ...@@ -354,23 +369,23 @@ class FragmentDataset(Dataset):
# convert to fragment index # convert to fragment index
frag_idx = self.valid_idx[idx] frag_idx = self.valid_idx[idx]
return self.get_raw(frag_idx) return self.get_raw(frag_idx)
def get_raw(self, frag_idx): def get_raw(self, frag_idx):
# lookup fragment # lookup fragment
rec_id, f_start, f_end, p_start, p_end = self.frag['frag_lookup'][frag_idx] rec_id, f_start, f_end, p_start, p_end = self.frag['frag_lookup'][frag_idx]
smiles = self.frag['frag_smiles'][frag_idx].decode('ascii') smiles = self.frag['frag_smiles'][frag_idx].decode('ascii')
conn = self.frag['frag_conn'][frag_idx] conn = self.frag['frag_conn'][frag_idx]
# lookup receptor # lookup receptor
rec_idx = self.rec['rec_mapping'][rec_id.decode('ascii')] rec_idx = self.rec['rec_mapping'][rec_id.decode('ascii')]
_, r_start, r_end = self.rec['rec_lookup'][rec_idx] _, r_start, r_end = self.rec['rec_lookup'][rec_idx]
# fetch data # fetch data
# f_coords = self.frag['frag_coords'][f_start:f_end] # f_coords = self.frag['frag_coords'][f_start:f_end]
# f_types = self.frag['frag_types'][f_start:f_end] # f_types = self.frag['frag_types'][f_start:f_end]
p_coords = self.frag['frag_coords'][p_start:p_end] p_coords = self.frag['frag_coords'][p_start:p_end]
r_coords = self.rec['rec_coords'][r_start:r_end] r_coords = self.rec['rec_coords'][r_start:r_end]
if self._lazy_loading and self.frag['frag_loaded'][frag_idx] == 0: if self._lazy_loading and self.frag['frag_loaded'][frag_idx] == 0:
frag_types = self.frag['frag_types'] frag_types = self.frag['frag_types']
frag_remapped = self.frag['frag_remapped'] frag_remapped = self.frag['frag_remapped']
...@@ -468,29 +483,29 @@ class FingerprintDataset(Dataset): ...@@ -468,29 +483,29 @@ class FingerprintDataset(Dataset):
def _load_fingerprints(self, fingerprint_file): def _load_fingerprints(self, fingerprint_file):
"""Loads fingerprint information.""" """Loads fingerprint information."""
f = h5py.File(fingerprint_file, 'r') f = h5py.File(fingerprint_file, 'r')
fingerprint_data = f['fingerprints'][()] fingerprint_data = f['fingerprints'][()]
fingerprint_smiles = f['smiles'][()] fingerprint_smiles = f['smiles'][()]
# create smiles->idx mapping # create smiles->idx mapping
fingerprint_mapping = {} fingerprint_mapping = {}
for i in range(len(fingerprint_smiles)): for i in range(len(fingerprint_smiles)):
sm = fingerprint_smiles[i].decode('ascii') sm = fingerprint_smiles[i].decode('ascii')
fingerprint_mapping[sm] = i fingerprint_mapping[sm] = i
fingerprints = { fingerprints = {
'fingerprint_data': fingerprint_data, 'fingerprint_data': fingerprint_data,
'fingerprint_mapping': fingerprint_mapping, 'fingerprint_mapping': fingerprint_mapping,
'fingerprint_smiles': fingerprint_smiles, 'fingerprint_smiles': fingerprint_smiles,
} }
f.close() f.close()
return fingerprints return fingerprints
def for_smiles(self, smiles): def for_smiles(self, smiles):
"""Return a Tensor of fingerprints for a list of smiles. """Return a Tensor of fingerprints for a list of smiles.
Args: Args:
smiles: size N list of smiles strings (as str not bytes) smiles: size N list of smiles strings (as str not bytes)
""" """
......
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
""" """
Contains code for gpu-accelerated grid generation. Contains code for gpu-accelerated grid generation.
""" """
...@@ -120,15 +135,15 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset, ...@@ -120,15 +135,15 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
# invisible atoms # invisible atoms
if mask == 0: if mask == 0:
continue continue
# point radius squared # point radius squared
r = point_radius r = point_radius
r2 = point_radius * point_radius r2 = point_radius * point_radius
# quick cube bounds check # quick cube bounds check
if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2: if abs(fx-tx) > r2 or abs(fy-ty) > r2 or abs(fz-tz) > r2:
continue continue
# value to add to this gridpoint # value to add to this gridpoint
val = 0 val = 0
...@@ -147,7 +162,7 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset, ...@@ -147,7 +162,7 @@ def gpu_gridify(grid, atom_num, atom_coords, atom_mask, layer_offset,
d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2 d2 = (fx-tx)**2 + (fy-ty)**2 + (fz-tz)**2
if d2 > r2: if d2 > r2:
continue continue
val = 1 val = 1
elif point_type == 2: # POINT_TYPE.CUBE elif point_type == 2: # POINT_TYPE.CUBE
# solid cube fill # solid cube fill
...@@ -290,7 +305,7 @@ def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5, ...@@ -290,7 +305,7 @@ def get_batch(data, batch_size=16, batch_set=None, width=48, res=0.5,
rot = fixed_rot rot = fixed_rot
if rot is None: if rot is None:
rot = rand_rot() rot = rand_rot()
if ignore_receptor: if ignore_receptor:
mol_gridify( mol_gridify(
cuda_grid, cuda_grid,
...@@ -420,5 +435,5 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer, ...@@ -420,5 +435,5 @@ def get_raw_batch(r_coords, r_types, p_coords, p_types, rec_typer, lig_typer,
point_type=point_type, point_type=point_type,
acc_type=acc_type acc_type=acc_type
) )
return torch_grid return torch_grid
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os import os
......
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import torch import torch
import torch.nn as nn import torch.nn as nn
...@@ -36,7 +50,7 @@ def broadcast_fn(fn, yp, yt): ...@@ -36,7 +50,7 @@ def broadcast_fn(fn, yp, yt):
def average_position(fingerprints, fn, norm=True): def average_position(fingerprints, fn, norm=True):
"""Returns the average ranking of the correct fragment relative to all """Returns the average ranking of the correct fragment relative to all
possible fragments. possible fragments.
Args: Args:
fingerprints: NxF tensor of fingerprint data fingerprints: NxF tensor of fingerprint data
fn: distance function to compare fingerprints fn: distance function to compare fingerprints
...@@ -54,7 +68,7 @@ def average_position(fingerprints, fn, norm=True): ...@@ -54,7 +68,7 @@ def average_position(fingerprints, fn, norm=True):
# number of fragment that are closer or equal # number of fragment that are closer or equal
count = torch.sum((dist <= p_dist[i]).to(torch.float)) count = torch.sum((dist <= p_dist[i]).to(torch.float))
c[i] = count c[i] = count
score = torch.mean(c) score = torch.mean(c)
return score return score
...@@ -148,7 +162,7 @@ def top_k_acc(fingerprints, fn, k, pre=''): ...@@ -148,7 +162,7 @@ def top_k_acc(fingerprints, fn, k, pre=''):
for j in range(len(k)): for j in range(len(k)):
c[i,j] = int(count < k[j]) c[i,j] = int(count < k[j])
score = torch.mean(c, 0) score = torch.mean(c, 0)
m = {'%sacc_%d' % (pre, h): v.item() for h,v in zip(k,score)} m = {'%sacc_%d' % (pre, h): v.item() for h,v in zip(k,score)}
......
# Copyright 2021 Jacob Durrant
# Licensed under the Apache License, Version 2.0 (the "License"); you may not
# use this file except in compliance with the License. You may obtain a copy
# of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
# WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
# License for the specific language governing permissions and limitations
# under the License.
import os import os
import json import json
...@@ -59,7 +73,7 @@ LOSS_TYPE = { ...@@ -59,7 +73,7 @@ LOSS_TYPE = {
class RunLog(object): class RunLog(object):
def __init__(self, args, models, wandb_project=None): def __init__(self, args, models, wandb_project=None):
"""Initialize a run logger. """Initialize a run logger.
Args: Args:
args: command line training arguments args: command line training arguments
models: {name: model} mapping models: {name: model} mapping
...@@ -71,7 +85,7 @@ class RunLog(object): ...@@ -71,7 +85,7 @@ class RunLog(object):
project=wandb_project,