Tschoui commited on
Commit
8a0d2b8
·
1 Parent(s): de43f62

✨ Add prediction pipeline for random forest classifier

Browse files
Files changed (2) hide show
  1. app.py +45 -5
  2. predict.py +52 -0
app.py CHANGED
@@ -1,7 +1,47 @@
1
- from fastapi import FastAPI
 
 
 
2
 
3
- app = FastAPI()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
- @app.get("/")
6
- def greet_json():
7
- return {"Hello": "World!"}
 
1
+ """
2
+ This is the main entry point for the FastAPI application.
3
+ The app handles the request to predict toxicity for a list of SMILES strings.
4
+ """
5
 
6
+ #---------------------------------------------------------------------------------------
7
+ # Dependencies and global variable definition
8
+ import os
9
+ from typing import List, Dict, Optional
10
+ from fastapi import FastAPI, Header, HTTPException
11
+ from pydantic import BaseModel, Field
12
+
13
+ from predict import predict
14
+
15
+ API_KEY = os.getenv("API_KEY") # set via Space Secrets
16
+
17
+ #---------------------------------------------------------------------------------------
18
+ class Request(BaseModel):
19
+ smiles: List[str] = Field(min_items=1, max_items=1000)
20
+
21
+ class Response(BaseModel):
22
+ predictions: dict
23
+ model_info: Dict[str, str] = {}
24
+
25
+ app = FastAPI(title="toxicity-api")
26
+
27
+ @app.get("/metadata")
28
+ def metadata():
29
+ return {
30
+ "name": "AwesomeTox",
31
+ "version": "1.0.0",
32
+ "max_batch_size": 256,
33
+ "tox_endpoints": ["mutagenicity","hepatotoxicity"],
34
+ }
35
+
36
+ @app.get("/healthz")
37
+ def healthz():
38
+ return {"ok": True}
39
+
40
+ @app.post("/predict", response_model=Response)
41
+ def predict(request: Request, authorization: str = Header(default="")):
42
+ if not API_KEY or authorization != f"Bearer {API_KEY}":
43
+ raise HTTPException(status_code=401, detail="Unauthorized")
44
+
45
+ predictions = predict(request.smiles)
46
+ return {"predictions": predictions, "model_info": {"name":"random_clf", "version":"1.0.0"}}
47
 
 
 
 
predict.py ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ This files includes a predict function for the Tox21.
3
+ As an input it takes a list of SMILES and it outputs a nested dictionary with
4
+ SMILES and target names as keys.
5
+ """
6
+
7
+ #---------------------------------------------------------------------------------------
8
+ # Dependencies
9
+ from Literal import List
10
+ import random
11
+ from collections import defaultdict
12
+ #---------------------------------------------------------------------------------------
13
+ class Tox21RandomClassifier():
14
+ """
15
+ A random classifier that assigns a random toxicity score to a given SMILES string.
16
+ """
17
+
18
+ def __init__(self):
19
+
20
+ self.target_names = [
21
+ "NR-AR",
22
+ "NR-AR-LBD",
23
+ "NR-AhR",
24
+ "NR-Aromatase",
25
+ "NR-ER",
26
+ "NR-ER-LBD",
27
+ "NR-PPAR-gamma",
28
+ "SR-ARE",
29
+ "SR-ATAD5",
30
+ "SR-HSE",
31
+ "SR-MMP",
32
+ "SR-p53"
33
+ ]
34
+
35
+ def predict(self, smiles_list:List[str]) -> dict:
36
+ """
37
+ Predicts all Tox21 targets for a given list of SMILES strings by assigning
38
+ random toxicity scores.
39
+ """
40
+
41
+ predictions = defaultdict(dict)
42
+ for smiles in smiles_list:
43
+ for target in self.target_names:
44
+ predictions[smiles][target] = random.random()
45
+ return predictions
46
+
47
+ def predict(smiles_list: List[str]) -> dict:
48
+ """
49
+ Applies the classifier to a list of SMILES strings.
50
+ """
51
+ model = Tox21RandomClassifier()
52
+ return model.predict(smiles_list)