Tschoui commited on
Commit
4c1b5d1
Β·
1 Parent(s): 8a0d2b8

πŸ› Bug fixing

Browse files
__pycache__/app.cpython-311.pyc ADDED
Binary file (2.84 kB). View file
 
__pycache__/predict.cpython-311.pyc ADDED
Binary file (2.14 kB). View file
 
app.py CHANGED
@@ -10,7 +10,7 @@ from typing import List, Dict, Optional
10
  from fastapi import FastAPI, Header, HTTPException
11
  from pydantic import BaseModel, Field
12
 
13
- from predict import predict
14
 
15
  API_KEY = os.getenv("API_KEY") # set via Space Secrets
16
 
@@ -42,6 +42,6 @@ def predict(request: Request, authorization: str = Header(default="")):
42
  if not API_KEY or authorization != f"Bearer {API_KEY}":
43
  raise HTTPException(status_code=401, detail="Unauthorized")
44
 
45
- predictions = predict(request.smiles)
46
  return {"predictions": predictions, "model_info": {"name":"random_clf", "version":"1.0.0"}}
47
 
 
10
  from fastapi import FastAPI, Header, HTTPException
11
  from pydantic import BaseModel, Field
12
 
13
+ from predict import predict as predict_func
14
 
15
  API_KEY = os.getenv("API_KEY") # set via Space Secrets
16
 
 
42
  if not API_KEY or authorization != f"Bearer {API_KEY}":
43
  raise HTTPException(status_code=401, detail="Unauthorized")
44
 
45
+ predictions = predict_func(request.smiles)
46
  return {"predictions": predictions, "model_info": {"name":"random_clf", "version":"1.0.0"}}
47
 
predict.py CHANGED
@@ -6,7 +6,7 @@ SMILES and target names as keys.
6
 
7
  #---------------------------------------------------------------------------------------
8
  # Dependencies
9
- from Literal import List
10
  import random
11
  from collections import defaultdict
12
  #---------------------------------------------------------------------------------------
 
6
 
7
  #---------------------------------------------------------------------------------------
8
  # Dependencies
9
+ from typing import List
10
  import random
11
  from collections import defaultdict
12
  #---------------------------------------------------------------------------------------
requirements.txt CHANGED
@@ -1,2 +1,4 @@
1
  fastapi
2
  uvicorn[standard]
 
 
 
1
  fastapi
2
  uvicorn[standard]
3
+ pytest
4
+ httpx
tests/test_app.py ADDED
@@ -0,0 +1,100 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for the FastAPI application
3
+ """
4
+
5
+ import pytest
6
+ from fastapi.testclient import TestClient
7
+ from unittest.mock import patch
8
+ import os
9
+
10
+ from app import app
11
+
12
+
13
+ @pytest.fixture
14
+ def client():
15
+ return TestClient(app)
16
+
17
+
18
+ class TestMetadataEndpoint:
19
+
20
+ def test_metadata(self, client):
21
+ response = client.get("/metadata")
22
+ assert response.status_code == 200
23
+ data = response.json()
24
+
25
+ assert data["name"] == "AwesomeTox"
26
+ assert data["version"] == "1.0.0"
27
+ assert data["max_batch_size"] == 256
28
+ assert data["tox_endpoints"] == ["mutagenicity", "hepatotoxicity"]
29
+
30
+
31
+ class TestHealthzEndpoint:
32
+
33
+ def test_healthz(self, client):
34
+ response = client.get("/healthz")
35
+ assert response.status_code == 200
36
+ assert response.json() == {"ok": True}
37
+
38
+
39
+ class TestPredictEndpoint:
40
+
41
+ @patch.dict(os.environ, {"API_KEY": "test-key"})
42
+ def test_predict_with_valid_auth(self, client):
43
+ headers = {"Authorization": "Bearer test-key"}
44
+ data = {"smiles": ["CCO"]}
45
+
46
+ response = client.post("/predict", json=data, headers=headers)
47
+ assert response.status_code == 200
48
+
49
+ result = response.json()
50
+ assert "predictions" in result
51
+ assert "model_info" in result
52
+ assert result["model_info"]["name"] == "random_clf"
53
+ assert result["model_info"]["version"] == "1.0.0"
54
+
55
+ @patch.dict(os.environ, {"API_KEY": "test-key"})
56
+ def test_predict_without_auth(self, client):
57
+ data = {"smiles": ["CCO"]}
58
+
59
+ response = client.post("/predict", json=data)
60
+ assert response.status_code == 401
61
+ assert response.json()["detail"] == "Unauthorized"
62
+
63
+ @patch.dict(os.environ, {"API_KEY": "test-key"})
64
+ def test_predict_with_invalid_auth(self, client):
65
+ headers = {"Authorization": "Bearer wrong-key"}
66
+ data = {"smiles": ["CCO"]}
67
+
68
+ response = client.post("/predict", json=data, headers=headers)
69
+ assert response.status_code == 401
70
+
71
+ @patch.dict(os.environ, {"API_KEY": "test-key"})
72
+ def test_predict_empty_smiles_list(self, client):
73
+ headers = {"Authorization": "Bearer test-key"}
74
+ data = {"smiles": []}
75
+
76
+ response = client.post("/predict", json=data, headers=headers)
77
+ assert response.status_code == 422 # Validation error due to min_items=1
78
+
79
+ @patch.dict(os.environ, {"API_KEY": "test-key"})
80
+ def test_predict_too_many_smiles(self, client):
81
+ headers = {"Authorization": "Bearer test-key"}
82
+ data = {"smiles": ["CCO"] * 1001} # Exceeds max_items=1000
83
+
84
+ response = client.post("/predict", json=data, headers=headers)
85
+ assert response.status_code == 422 # Validation error due to max_items=1000
86
+
87
+ @patch.dict(os.environ, {"API_KEY": "test-key"})
88
+ def test_predict_multiple_smiles(self, client):
89
+ headers = {"Authorization": "Bearer test-key"}
90
+ data = {"smiles": ["CCO", "CCN", "CCC"]}
91
+
92
+ response = client.post("/predict", json=data, headers=headers)
93
+ assert response.status_code == 200
94
+
95
+ result = response.json()
96
+ predictions = result["predictions"]
97
+
98
+ for smiles in data["smiles"]:
99
+ assert smiles in predictions
100
+ assert len(predictions[smiles]) == 12
tests/test_predict.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Tests for the predict.py module
3
+ """
4
+
5
+ from predict import predict, Tox21RandomClassifier
6
+
7
+
8
+ class TestTox21RandomClassifier:
9
+
10
+ def test_init(self):
11
+ classifier = Tox21RandomClassifier()
12
+ assert len(classifier.target_names) == 12
13
+ expected_targets = [
14
+ "NR-AR", "NR-AR-LBD", "NR-AhR", "NR-Aromatase",
15
+ "NR-ER", "NR-ER-LBD", "NR-PPAR-gamma",
16
+ "SR-ARE", "SR-ATAD5", "SR-HSE", "SR-MMP", "SR-p53"
17
+ ]
18
+ assert classifier.target_names == expected_targets
19
+
20
+ def test_predict_single_smiles(self):
21
+ classifier = Tox21RandomClassifier()
22
+ smiles_list = ["CCO"]
23
+ result = classifier.predict(smiles_list)
24
+
25
+ assert "CCO" in result
26
+ assert len(result["CCO"]) == 12
27
+
28
+ for target in classifier.target_names:
29
+ assert target in result["CCO"]
30
+ assert 0 <= result["CCO"][target] <= 1
31
+
32
+ def test_predict_multiple_smiles(self):
33
+ classifier = Tox21RandomClassifier()
34
+ smiles_list = ["CCO", "CCN", "CCC"]
35
+ result = classifier.predict(smiles_list)
36
+
37
+ assert len(result) == 3
38
+ for smiles in smiles_list:
39
+ assert smiles in result
40
+ assert len(result[smiles]) == 12
41
+
42
+ for target in classifier.target_names:
43
+ assert target in result[smiles]
44
+ assert 0 <= result[smiles][target] <= 1
45
+
46
+ def test_predict_empty_list(self):
47
+ classifier = Tox21RandomClassifier()
48
+ result = classifier.predict([])
49
+ assert result == {}
50
+
51
+
52
+ class TestPredictFunction:
53
+
54
+ def test_predict_function(self):
55
+ smiles_list = ["CCO", "CCN"]
56
+ result = predict(smiles_list)
57
+
58
+ assert len(result) == 2
59
+ for smiles in smiles_list:
60
+ assert smiles in result
61
+ assert len(result[smiles]) == 12
62
+
63
+ def test_predict_function_empty(self):
64
+ result = predict([])
65
+ assert result == {}