-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #4 from prokube-ai/molecules
Molecules
- Loading branch information
Showing
20 changed files
with
523 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,40 @@ | ||
# GitLab CI pipeline: builds the chem-util image with kaniko and pushes it to
# the project's container registry.
default:
  image:
    name: gcr.io/kaniko-project/executor:debug
    entrypoint: [""]

variables:
  REGISTRY_HOME: "${CI_REGISTRY}/${CI_PROJECT_PATH}/"
  # Debug tracing prints the value of every CI variable (including
  # CI_REGISTRY_PASSWORD, used in before_script below) into the job log.
  # Keep it off; enable it only temporarily on a restricted project.
  CI_DEBUG_TRACE: "false"

cache:
  # Cache paths must live inside the project directory, so the absolute
  # /kaniko entry could never be cached and has been dropped.
  paths:
    - .cache/pip
    - .cache/kaniko

stages:
  - build-image

before_script:
  - mkdir -p /kaniko/.docker
  # adding our cert to kaniko's additional certs
  - cat ${CI_SERVER_TLS_CA_FILE} >> /kaniko/ssl/certs/additional-ca-cert-bundle.crt
  # Creating kaniko config
  - >
    echo "{\"auths\":{\"${CI_REGISTRY}\":{\"auth\":\"$(printf "%s:%s" "${CI_REGISTRY_USER}" "${CI_REGISTRY_PASSWORD}"
    | base64 | tr -d '\n')\"}}}" > /kaniko/.docker/config.json

chem-util-build:
  stage: build-image
  script:
    - >
      /kaniko/executor
      --context "${CI_PROJECT_DIR}/images/molecules/"
      --dockerfile "${CI_PROJECT_DIR}/images/molecules/Dockerfile"
      --destination "${CI_REGISTRY}/${CI_PROJECT_PATH}/chem-util:latest"
      --destination "${CI_REGISTRY}/${CI_PROJECT_PATH}/chem-util:${CI_COMMIT_SHORT_SHA}"
  # Only rebuild the image when something under images/molecules changes.
  rules:
    - changes:
        - images/molecules/**/*
Binary file not shown.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,2 @@ | ||
# Images | ||
This folder contains code and dockerfiles to create container images used elsewhere. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,14 @@ | ||
# Image for the chem-util CLI: a micromamba base environment built from
# env.yaml, with the application code copied into /app.
FROM mambaorg/micromamba:1.4.9
WORKDIR /app

# Install dependencies first so this layer is reused when only code changes.
COPY --chown=$MAMBA_USER:$MAMBA_USER env.yaml /tmp/env.yaml
RUN micromamba install -y -n base -f /tmp/env.yaml && \
    micromamba clean --all --yes

ARG MAMBA_DOCKERFILE_ACTIVATE=1

COPY . .

USER root
# ENV persists into later layers and into the running container. The previous
# `RUN export PATH=...` only changed the shell of that single build step and
# was discarded as soon as the step finished.
ENV PATH="${PATH}:/opt/conda/bin/"
ENTRYPOINT ["/opt/conda/bin/python", "/app/chem-util.py"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,31 @@ | ||
# chem-util | ||
A simple python script that reads in a CSV with SMILES, calculates molecular fingerprints, trains a model and | ||
evaluates it. It is the basis for the `pipelines/molecules` pipeline, where you will also find the main README. | ||
|
||
## Local run | ||
Help | ||
```shell | ||
chem-util.py --help | ||
``` | ||
|
||
All steps | ||
```shell | ||
mkdir /tmp/chem | ||
chem-util.py preprocess -i ../../data/ames.csv.zip -o /tmp/chem/processed.csv.zip | ||
chem-util.py split -i /tmp/chem/processed.csv.zip -o /tmp/chem/train.csv.zip -t /tmp/chem/test.csv.zip | ||
chem-util.py train -i /tmp/chem/train.csv.zip -o /tmp/chem/model.joblib | ||
chem-util.py evaluate -i /tmp/chem/test.csv.zip -m /tmp/chem/model.joblib | ||
``` | ||
|
||
## Local build | ||
```shell | ||
docker build --platform linux/amd64 . -t chem-util | ||
# Test run | ||
docker run chem-util --help | ||
# Jumping into the container with bound local folder | ||
docker run --entrypoint /bin/bash -it -v <local-path>:<container-path> chem-util | ||
``` | ||
|
||
# model-serving | ||
`model-serving.py` contains the code of the KServe transformer and custom predictor. It can be deployed as a KServe | ||
InferenceService by using [model.yaml](../../serving/molecules/model.yaml). |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,113 @@ | ||
import click | ||
import logging | ||
import pandas as pd | ||
from src.features import get_cfps | ||
from src.utils import mol2html | ||
from rdkit.Chem import PandasTools | ||
from sklearn.model_selection import train_test_split | ||
from sklearn.ensemble import RandomForestClassifier | ||
import joblib | ||
from sklearn.metrics import roc_auc_score | ||
|
||
|
||
# Configure root logging once at import time (INFO level, timestamped records)
# and create the module-level logger used by every command in this file.
logging.basicConfig(level=logging.INFO, format="%(asctime)s:%(levelname)s:%(name)s:%(message)s")
logger = logging.getLogger(__name__)
|
||
|
||
# Root click command group. It does no work itself; the sub-commands in this
# file attach to it through @cli.command(...). A comment is used instead of a
# docstring because click would show a docstring as the group's help text.
@click.group()
def cli():
    pass
|
||
|
||
@cli.command('preprocess')
@click.option('--input-data', '-i', help="Path to the input data (csv.zip).", required=True, type=str)
@click.option('--output-data', '-o', help="Path to the output.", required=True, type=str)
@click.option('--fp-bits', '-n', help="Number of the fingerprint bits.", required=False, type=int,
              default=1024)
@click.option('--id_col', '-d', help="Name of the ID col.", required=False, type=str, default='ID')
@click.option('--target', '-t', help="Name of the target col.", required=False, type=str, default='class')
@click.option('--sample', '-s', help="Path to where to store class samples.", required=False, type=str)
def preprocess(input_data, output_data, fp_bits, id_col, target, sample):
    """Compute fingerprint bit columns for every molecule in a zipped CSV."""
    # input_data : zipped CSV with a 'Smiles' column plus the id/target columns.
    # output_data: zipped CSV holding id_col, target and bit_0..bit_{fp_bits-1}.
    # sample     : optional path; when given, one rendered molecule per class
    #              is written there as markdown/HTML.
    logger.info(f"Reading in {input_data}.")
    df = pd.read_csv(input_data, index_col=0, compression='zip')
    logger.info("Calculating features.")
    PandasTools.AddMoleculeColumnToFrame(df, smilesCol='Smiles')  # adding mol objects

    # NOTE(review): rows whose SMILES cannot be parsed are expected to carry a
    # missing molecule here; they are skipped with a warning rather than
    # crashing the fingerprint calculation. Confirm against the RDKit version.
    unparsed = df['ROMol'].isna()
    if unparsed.any():
        logger.warning(f"Skipping {int(unparsed.sum())} rows with unparsable SMILES.")
        df = df.loc[~unparsed]

    fp_cols = [f'bit_{x}' for x in range(fp_bits)]
    # The fingerprint frame must share df's index. Without it the frame gets a
    # fresh 0..n-1 index and the bits are matched to the wrong rows (or become
    # NaN) whenever the CSV index is not exactly 0..n-1.
    fingerprints = pd.DataFrame(
        data=[get_cfps(mol, nBits=fp_bits) for mol in df['ROMol']],
        columns=fp_cols,
        index=df.index,
    )
    df = pd.concat([df, fingerprints], axis=1)

    logger.info(f"Storing to {output_data}.")
    df[[id_col, target] + fp_cols].to_csv(output_data, compression='zip')

    if sample:
        parts = ["# Sample molecules\n"]
        # dropna(): a missing label never equals itself, so it would select no
        # rows and make iloc[0] fail.
        for c in df[target].dropna().unique():
            first_mol = df[df[target] == c].iloc[0]['ROMol']
            parts.append(f"## Class {c}\n")
            parts.append(mol2html(first_mol, legend=f'Class: {c}'))
            parts.append("\n")
        with open(sample, 'w', encoding='utf-8') as f:
            f.write("".join(parts))
|
||
|
||
@cli.command('split')
@click.option('--input-data', '-i', help="Path to the input data.", required=True, type=str)
@click.option('--output-train', '-o', help="Path to the train output.", required=True, type=str)
@click.option('--output-test', '-t', help="Path to the test output.", required=True, type=str)
@click.option('--test-fraction', '-f', help="Fraction of the dataset to use for evaluation.",
              required=False, type=float, default=0.2)
# '-s' replaces the original '-f', which collided with the short flag of
# --test-fraction so that the two options could not both be reached via '-f'.
@click.option('--seed', '-s', help="Random seed.",
              required=False, type=int, default=42)
def split(input_data, output_train, output_test, test_fraction, seed):
    """Randomly split a zipped CSV into a train file and a test file."""
    # Raises ValueError when test_fraction is not strictly between 0 and 1.
    if not 0.0 < test_fraction < 1.0:
        raise ValueError(f"test_fraction should be between 0 and 1. Provided was {test_fraction}")
    logger.info(f"Reading in {input_data}.")
    df = pd.read_csv(input_data, index_col=0, compression='zip')
    logger.info("Splitting.")
    # Splitting the frame directly selects the same rows as splitting a list of
    # row positions and indexing with iloc afterwards.
    train, test = train_test_split(df, random_state=seed, test_size=test_fraction)
    logger.info(f"Storing to {output_train}.")
    train.to_csv(output_train, compression='zip')
    logger.info(f"Storing to {output_test}.")
    test.to_csv(output_test, compression='zip')
|
||
|
||
@cli.command('train')
@click.option('--input-data', '-i', help="Path to the training data.", required=True, type=str)
@click.option('--output-model', '-o', help="Path to the output model.", required=True, type=str)
@click.option('--target', '-t', help="Name of the target col.", required=False, type=str, default='class')
@click.option('--n-trees', '-n', help="Number of trees.", required=False, type=int, default=16)
def train(input_data, output_model, target, n_trees):
    # Fits a random forest on the fingerprint columns of the training CSV and
    # writes the fitted estimator to output_model with joblib.
    # (Comment rather than docstring: click would show a docstring as help text.)
    logger.info(f"Reading in {input_data}.")
    frame = pd.read_csv(input_data, index_col=0, compression='zip')

    # Every column whose name contains 'bit_' is treated as a feature.
    bit_columns = list(filter(lambda column: 'bit_' in column, frame.columns))
    logger.info(f'Detected {len(bit_columns)} fingerprint columns.')

    logger.info(f'Fitting a RandomForestClassifier model with {n_trees} trees.')
    features, labels = frame[bit_columns], frame[target]
    model = RandomForestClassifier(n_estimators=n_trees).fit(features, labels)

    logger.info(f'Saving model {output_model}')
    joblib.dump(model, output_model)
|
||
|
||
@cli.command('evaluate')
@click.option('--input-data', '-i', help="Path to the test data.", required=True, type=str)
@click.option('--input-model', '-m', help="Path to the model.", required=True, type=str)
@click.option('--output-metrics', '-o', help="Filename where to store metrics.", required=False, type=str)
# Mirrors the --target option of the other commands; the column name used to
# be hard-coded to 'class' here, which remains the default.
@click.option('--target', '-t', help="Name of the target col.", required=False, type=str, default='class')
def evaluate(input_data, input_model, output_metrics, target='class'):
    """Score a stored model on a test CSV and report its ROC AUC."""
    # output_metrics: optional path; when given, the score is written there as text.
    logger.info(f"Reading in {input_data}.")
    df = pd.read_csv(input_data, index_col=0, compression='zip')
    logger.info(f"Reading in the model {input_model}.")
    # NOTE: joblib.load unpickles the file, so only load models from trusted sources.
    clf = joblib.load(input_model)
    fp_cols = [x for x in df.columns if 'bit_' in x]
    # The second probability column is used as the score of the positive class.
    score = roc_auc_score(df[target].values, clf.predict_proba(df[fp_cols])[:, 1])
    logger.info(f'Model roc auc score is: {score}.')
    if output_metrics:
        with open(output_metrics, 'w', encoding='utf-8') as f:
            f.write(str(score))
|
||
|
||
# Run the click command group only when the file is executed as a script.
if __name__ == '__main__':
    cli()
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,12 @@ | ||
name: base | ||
channels: | ||
- conda-forge | ||
dependencies: | ||
- python=3.10.12 | ||
- pandas=2.0.3 | ||
- click=8.1.3 | ||
- numpy=1.25.2 | ||
- scikit-learn=1.3.0 | ||
- rdkit=2023.03.3 | ||
- pip: | ||
- kserve==0.10.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,47 @@ | ||
import kserve | ||
import click | ||
from src.serving import MolTransformer, MolPredictor | ||
|
||
# Default for the --model_name option of both serve commands below.
DEFAULT_MODEL_NAME = "model"
|
||
|
||
# Root click command group. It does no work itself; the serve_* commands in
# this file attach to it through @cli.command(...). A comment is used instead
# of a docstring because click would show a docstring as the group's help text.
@click.group()
def cli():
    pass
|
||
|
||
@cli.command('serve_transformer', context_settings=dict(
    ignore_unknown_options=True,
    allow_extra_args=True,
))
@click.option('--model_name', default=DEFAULT_MODEL_NAME,
              help='The name that the model is served under.', type=str)
@click.option('--predictor_host', help='The URL for the model predict function', required=True, type=str)
@click.option('--n_bits', help='Number of bits to use for the fingerprint', required=False, type=int,
              default=1024)
def serve_transformer(model_name, predictor_host, n_bits):
    # Builds a MolTransformer from the CLI options and hands it to a KServe
    # model server. Unknown extra CLI arguments are tolerated by the
    # context settings above.
    models = [
        MolTransformer(name=model_name, predictor_host=predictor_host, n_bits=n_bits),
    ]
    kserve.ModelServer().start(models=models)
|
||
|
||
@cli.command('serve_predictor', context_settings=dict(
    ignore_unknown_options=True,
    allow_extra_args=True,
))
@click.option('--model_name', default=DEFAULT_MODEL_NAME,
              help='The name that the model is served under.', type=str)
def serve_predictor(model_name):
    # Builds a MolPredictor under the given name and hands it to a KServe
    # model server. Unknown extra CLI arguments are tolerated by the
    # context settings above.
    models = [MolPredictor(name=model_name)]
    kserve.ModelServer().start(models=models)
|
||
|
||
# Run the click command group only when the file is executed as a script.
if __name__ == "__main__":
    cli()
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
from rdkit import DataStructs | ||
from rdkit.Chem import AllChem | ||
from rdkit.Chem.rdchem import Mol | ||
import numpy as np | ||
|
||
|
||
def get_cfps(mol: Mol, radius: int = 1, nBits: int = 1024, useFeatures: bool = False,
             dtype: np.dtype = np.int8) -> np.ndarray:
    """Return the circular (Morgan) fingerprint of a molecule as a numpy array.

    See https://rdkit.org/docs/GettingStartedInPython.html#morgan-fingerprints-circular-fingerprints

    Parameters
    ----------
    mol : rdkit.Chem.rdchem.Mol
        Molecule to fingerprint.
    radius : int
        Radius passed to the Morgan fingerprint.
    nBits : int
        Number of bits of the hashed fingerprint.
    useFeatures : bool
        When True, feature fingerprints (FCFP) are requested instead of the
        default ECFP ones.
    dtype : np.dtype
        Element type of the returned array.

    Returns
    -------
    np.ndarray
        Array filled from the fingerprint bit vector.
    """
    # The array starts with one element; ConvertToNumpyArray is expected to
    # resize it to the fingerprint length (RDKit behavior, not visible here).
    fingerprint = np.zeros((1,), dtype)
    bit_vector = AllChem.GetMorganFingerprintAsBitVect(
        mol, radius, nBits=nBits, useFeatures=useFeatures)
    DataStructs.ConvertToNumpyArray(bit_vector, fingerprint)
    return fingerprint
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,48 @@ | ||
import joblib | ||
from kserve import Model, constants | ||
from kserve.errors import InferenceError, ModelMissingError | ||
from typing import Dict | ||
import logging | ||
from rdkit import Chem | ||
from .features import get_cfps | ||
import os | ||
|
||
|
||
# Configure root logging at import time, using the log level constant exposed by kserve.
logging.basicConfig(level=constants.KSERVE_LOGLEVEL)
|
||
|
||
class MolTransformer(Model):
    """KServe transformer that turns SMILES strings into fingerprint bit lists."""

    def __init__(self, name: str, predictor_host: str, n_bits: int = 1024, headers: Dict[str, str] = None):
        """Store the predictor address and fingerprint size and mark the model ready.

        ``headers`` is accepted for interface compatibility and is not used.
        """
        super().__init__(name)
        self.predictor_host = predictor_host
        self.n_bits = n_bits
        self.ready = True

    def preprocess(self, inputs: Dict, headers: Dict[str, str] = None) -> Dict:
        """Replace every SMILES string in ``inputs['instances']`` by its fingerprint.

        Returns a dict with a single ``'instances'`` key holding one list of
        plain Python numbers per input molecule.

        Raises ``InvalidInput`` when a SMILES string cannot be parsed.
        """
        # Local import: the module's top-level imports do not include this name.
        from kserve.errors import InvalidInput

        fingerprints = []
        for smiles in inputs['instances']:
            mol = Chem.MolFromSmiles(smiles)
            # RDKit returns None for SMILES it cannot parse. Reject those here
            # with a clear client error instead of failing inside get_cfps.
            if mol is None:
                raise InvalidInput(f"Could not parse SMILES string: {smiles!r}")
            # tolist() converts the numpy values to plain Python numbers.
            fingerprints.append(get_cfps(mol, nBits=self.n_bits).tolist())
        return {'instances': fingerprints}

    def postprocess(self, inputs: Dict, headers: Dict[str, str] = None) -> Dict:
        """Return the predictor response unchanged."""
        return inputs
|
||
|
||
class MolPredictor(Model):
    """KServe predictor that serves a joblib-stored model from /mnt/models."""

    def __init__(self, name: str):
        """Register the model under ``name`` and load it immediately."""
        super().__init__(name)
        self.name = name
        self.ready = False
        self.load()

    def load(self):
        """Load the model file named by the MODEL_FILENAME environment variable.

        Returns True once the model is loaded.

        Raises KeyError (with an explanatory message) when MODEL_FILENAME is not set.
        """
        try:
            filename = os.environ["MODEL_FILENAME"]
        except KeyError as err:
            raise KeyError(
                "MODEL_FILENAME environment variable must be set to the model file name "
                "inside /mnt/models"
            ) from err
        # NOTE: joblib.load unpickles the file, so only mount trusted model artifacts.
        self.model = joblib.load(f'/mnt/models/{filename}')
        self.ready = True
        # Logged rather than printed so the message follows the configured log level.
        logging.getLogger(__name__).info(f"Loaded {self.model}")
        return self.ready

    def predict(self, payload: Dict, headers: Dict[str, str] = None) -> Dict:
        """Run the model on ``payload['instances']`` and return its predictions.

        Returns a dict with a single ``'predictions'`` key holding a list.

        Raises InferenceError when the underlying model call fails.
        """
        instances = payload["instances"]
        try:
            result = self.model.predict(instances).tolist()
        except Exception as e:
            # Chain the original exception so its traceback is preserved.
            raise InferenceError(str(e)) from e
        return {"predictions": result}
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,10 @@ | ||
from rdkit.Chem import Draw | ||
from base64 import b64encode | ||
# Width and height, in pixels, of the rendered molecule image and its <div>.
MOL_SIZE = (200, 200)


def mol2html(mol, legend=""):
    """Render a molecule as an HTML snippet with an inline base64 PNG.

    Parameters
    ----------
    mol
        Molecule handed to RDKit's drawing code.
    legend : str
        Legend text passed on to the drawing call.

    Returns
    -------
    str
        A ``<div>`` sized to ``MOL_SIZE`` that wraps an ``<img>`` whose source
        is the PNG encoded as a data URI.
    """
    width, height = MOL_SIZE
    # NOTE(review): Draw._moltoimg is a private RDKit helper and may change
    # between RDKit releases.
    png_bytes = Draw._moltoimg(mol, MOL_SIZE, [], legend=legend, returnPNG=True, kekulize=True)
    encoded = b64encode(png_bytes).decode('ascii')
    opening_tag = f'<div style="width: {width}px; height: {height}px" data-content="rdkit/molecule">'
    image_tag = f'<img src="data:image/png;base64,{encoded}" alt="Mol"/>'
    return f'{opening_tag}{image_tag}</div>'
Oops, something went wrong.