ag added service API using docker and fastAPI
This commit is contained in:
18
service/Dockerfile
Normal file
18
service/Dockerfile
Normal file
@@ -0,0 +1,18 @@
|
||||
|
||||
#
|
||||
FROM python:3.9
|
||||
|
||||
#
|
||||
WORKDIR /code
|
||||
|
||||
#
|
||||
COPY ./requirements.txt /code/requirements.txt
|
||||
|
||||
#
|
||||
RUN pip install --no-cache-dir --upgrade -r /code/requirements.txt
|
||||
|
||||
#
|
||||
COPY ./app /code/app
|
||||
|
||||
#
|
||||
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "80"]
|
||||
0
service/app/__init__.py
Normal file
0
service/app/__init__.py
Normal file
27
service/app/main.py
Normal file
27
service/app/main.py
Normal file
@@ -0,0 +1,27 @@
|
||||
from typing import Union
|
||||
from .predict import predict
|
||||
from fastapi import FastAPI
|
||||
import numpy as np
|
||||
app = FastAPI()
|
||||
|
||||
|
||||
@app.get("/")
|
||||
def read_root():
|
||||
return {"Hello": "World"}
|
||||
|
||||
|
||||
@app.get("/predict/{data}")
|
||||
def read_item(data: str):
|
||||
final_json = {'probabilities':{}}
|
||||
|
||||
try:
|
||||
res = predict(data)
|
||||
for i in range(5):
|
||||
final_json['probabilities'][f"india_{i}"] = round(float(res[i]),2)
|
||||
final_json['success'] = True
|
||||
except Exception as e:
|
||||
final_json['success'] = False
|
||||
final_json['error'] = repr(e)
|
||||
|
||||
|
||||
return final_json
|
||||
BIN
service/app/metadata.pkl
Normal file
BIN
service/app/metadata.pkl
Normal file
Binary file not shown.
1
service/app/model.json
Normal file
1
service/app/model.json
Normal file
File diff suppressed because one or more lines are too long
60
service/app/predict.py
Normal file
60
service/app/predict.py
Normal file
@@ -0,0 +1,60 @@
|
||||
|
||||
|
||||
import json
|
||||
import xgboost as xgb
|
||||
import pickle
|
||||
import pandas as pd
|
||||
import numpy as np
|
||||
def predict(json_input:str):
|
||||
|
||||
with open('app/metadata.pkl','rb') as f:
|
||||
to_remove,use_small,evacuations,encoders = pickle.load(f)
|
||||
|
||||
|
||||
data = pd.DataFrame(json.loads(json_input[1:-1]),index=[0])
|
||||
|
||||
data.drop(columns=['dateandtime','skiarea_id','day_of_year','minute_of_day','year'], inplace=True, errors='ignore')
|
||||
|
||||
##evacuation_vehicles must be explicitated and workaround for other!
|
||||
evacuation = data.evacuation_vehicles.values[0]
|
||||
if isinstance(evacuation,str):
|
||||
evacuation = [evacuation]
|
||||
for c in evacuations:
|
||||
data[c] = False
|
||||
for c in evacuation:
|
||||
data[c] = True
|
||||
|
||||
for c in evacuation:
|
||||
if c not in evacuations:
|
||||
data['other'] = True
|
||||
break
|
||||
|
||||
data.drop(columns=['town','province','evacuation_vehicles'],inplace=True, errors='ignore')
|
||||
|
||||
|
||||
data['age'] = data['age'].astype(np.float32)
|
||||
|
||||
|
||||
|
||||
for c in data.columns:
|
||||
if c not in ['india','age','season','skiarea_name','destination']:
|
||||
data[c] = data[c].astype('str')
|
||||
if use_small:
|
||||
for c in to_remove.keys():
|
||||
for k in to_remove[c]:
|
||||
data.loc[data[c]==k,c] = 'other'
|
||||
for c in data.columns:
|
||||
if c not in ['age','season','skiarea_name','india']:
|
||||
data[c] = data[c].fillna('None')
|
||||
if use_small:
|
||||
data[c] = pd.Categorical( encoders['small'][c].transform(data[c]), categories=list(range(len(encoders['small'][c].classes_))), ordered=False)
|
||||
else:
|
||||
data[c] = pd.Categorical( encoders['normal'][c].transform(data[c]), categories=list(range(len(encoders['normal'][c].classes_))), ordered=False)
|
||||
|
||||
bst_FS = xgb.Booster()
|
||||
bst_FS.load_model("app/model.json")
|
||||
|
||||
dtest_FS = xgb.DMatrix(data[bst_FS.feature_names],enable_categorical=True)
|
||||
preds = bst_FS.predict(dtest_FS)
|
||||
|
||||
return preds[0]*100
|
||||
31
service/requirements.txt
Normal file
31
service/requirements.txt
Normal file
@@ -0,0 +1,31 @@
|
||||
annotated-types==0.6.0
|
||||
anyio==4.3.0
|
||||
click==8.1.7
|
||||
fastapi==0.110.0
|
||||
h11==0.14.0
|
||||
httptools==0.6.1
|
||||
idna==3.6
|
||||
joblib==1.3.2
|
||||
numpy==1.26.4
|
||||
pandas==2.2.1
|
||||
pydantic==2.6.4
|
||||
pydantic_core==2.16.3
|
||||
python-dateutil==2.9.0.post0
|
||||
python-dotenv==1.0.1
|
||||
pytz==2024.1
|
||||
PyYAML==6.0.1
|
||||
scikit-learn==1.4.1.post1
|
||||
scipy==1.12.0
|
||||
setuptools==68.2.2
|
||||
six==1.16.0
|
||||
sniffio==1.3.1
|
||||
starlette==0.36.3
|
||||
threadpoolctl==3.3.0
|
||||
typing_extensions==4.10.0
|
||||
tzdata==2024.1
|
||||
uvicorn==0.28.0
|
||||
uvloop==0.19.0
|
||||
watchfiles==0.21.0
|
||||
websockets==12.0
|
||||
wheel==0.41.2
|
||||
xgboost==2.0.3
|
||||
Reference in New Issue
Block a user