43. 蘑菇分类模型部署和推理#
43.1. 介绍#
本次挑战重点在于对 scikit-learn 训练模型的保存和部署,你需要按要求完成模型的线上部署并使之能对新的数据进行推理。
43.2. 知识点#
毒蘑菇分类预测
模型部署和推理
挑战选择了 UCI Machine Learning 提供的 蘑菇分类数据集,其采集了 8124 个蘑菇的样本,包含这些蘑菇的各类物理特性,例如气味、尺寸、颜色等。最终,这些样本被标记为 2 类:可食用和有毒。
挑战采样了原数据集中的 8000 条数据,你可以直接预览这些样本:
# 下载数据集
wget -nc https://cdn.aibydoing.com/aibydoing/files/mushrooms.csv
import pandas as pd
# 挑战所需训练数据集,复制链接粘贴到浏览器即可下载
df = pd.read_csv("mushrooms.csv")
df.head()
class | cap-shape | cap-surface | cap-color | bruises | odor | gill-attachment | gill-spacing | gill-size | gill-color | ... | stalk-surface-below-ring | stalk-color-above-ring | stalk-color-below-ring | veil-type | veil-color | ring-number | ring-type | spore-print-color | population | habitat | |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | p | x | s | w | f | c | f | c | n | p | ... | s | w | w | p | w | o | p | n | s | d |
1 | p | x | s | e | f | s | f | c | n | b | ... | k | w | w | p | w | o | e | w | v | p |
2 | p | k | s | e | f | y | f | c | n | b | ... | s | p | p | p | w | o | e | w | v | d |
3 | p | f | f | g | f | f | f | c | b | p | ... | k | b | n | p | w | o | l | h | y | g |
4 | e | f | f | n | f | n | f | w | b | h | ... | s | w | w | p | w | o | e | k | s | g |
5 rows × 23 columns
其中,class=e
表示可食用,class=p
表示有毒。其余列为特征数据。
接下来,你需要利用该数据集训练一个毒蘑菇分类器,并使用 Flask 将保存好的模型部署为 API 接口,可通过 HTTP 请求的方式获得推理结果。
挑战:参考前序实验中泰坦尼克号生存预测模型,训练毒蘑菇分类器,并完成模型部署。
规定:数据特征处理,算法选择等方式不定,可以自由发挥,只需要满足最终获得蘑菇是否有毒的结果即可。
挑战测试说明
本次挑战推荐在线下完成,同时你需要使用 scikit-learn
训练并保存模型,最终使用 Flask 完成 Web 应用构建。启动 Flask
后,可在本地向
localhost
发起 POST
请求获得推理结果。测试时推荐使用原数据集中的样本,传入数据需为
JSON 类型。
测试示例代码如下:
wget -nc https://cdn.aibydoing.com/aibydoing/files/mushrooms_test.csv
import json
import requests
import pandas as pd
df = pd.read_csv("mushrooms_test.csv") # 读取测试数据集
sample_data = df.sample(1).to_json() # 从原数据中随机取 1 条用于测试推理,并转换成 JSON 样式
sample_json = json.loads(sample_data) # 将 Pandas 转换的 JSON 样式数据处理成 JSON 类型
requests.post(url="http://localhost:5000", json=sample_json).content # 建立 POST 请求,并发送数据请求
期望输出
请求返回的内容和样式自定义,但至少需要返回推理类别结果。
参考答案 Exercise 43.1
# 加载数据集
import pandas as pd
import warnings
warnings.filterwarnings('ignore')
df = pd.read_csv("mushrooms.csv")
df.tail()
# 模型训练和保存
from sklearn.ensemble import RandomForestClassifier
from sklearn.model_selection import cross_val_score
import joblib
X = pd.get_dummies(df.iloc[:, 1:]) # 读取特征并独热编码
y = df['class'] # 目标值
model = RandomForestClassifier() # 随机森林
print(cross_val_score(model, X, y, cv=5).mean()) # 交叉验证结果
model.fit(X, y) # 训练模型
joblib.dump(model, "mushrooms.pkl") # 保存模型
print("model saved.")
# 构建 Flask Web
%%writefile predict.py
# 将此单元格代码写入 predict.py 文件方便后面执行
import joblib
import pandas as pd
from flask import Flask, request, jsonify
app = Flask(__name__)
@app.route("/", methods=["POST"]) # 请求方法为 POST
def inference():
query_df = pd.DataFrame(request.json) # 将 JSON 变为 DataFrame
df = pd.read_csv("mushrooms.csv") # 读取数据
X = pd.get_dummies(df.iloc[:, 1:]) # 读取特征并独热编码
query = pd.get_dummies(query_df).reindex(columns=X.columns, fill_value=0) # 将请求数据 DataFrame 处理成独热编码样式
clf = joblib.load('mushrooms.pkl') # 加载模型
prediction = clf.predict(query) # 模型推理
return jsonify({"prediction": list(prediction)}) # 返回推理结果
# Notebook 中必须以子进程才能正常启动 Flask
import time
import subprocess as sp
# 启动子进程执行 Flask app
server = sp.Popen("FLASK_APP=predict.py flask run", shell=True)
time.sleep(5) # 等待 5 秒保证 Flask 启动成功
server
import json
# 从测试数据中取 1 条用于测试推理
df_test = pd.read_csv("mushrooms_test.csv")
sample_data = df.sample(1).to_json(orient='records')
sample_json = json.loads(sample_data)
sample_json
import requests
requests.post(url="http://localhost:5000", json=sample_json).content # 建立 POST 请求,并发送数据请求
server.terminate() # 结束子进程,关闭端口占用
○ 欢迎分享本文链接到你的社交账号、博客、论坛等。更多的外链会增加搜索引擎对本站收录的权重,从而让更多人看到这些内容。