name: MLflow Patterns
description: ML experiment tracking, model registry, and deployment with MLflow for reproducible machine learning workflows.
MLflow Patterns
Overview
MLflow is an open-source platform for managing the ML lifecycle, covering experiment tracking, model packaging, the model registry, and deployment. It helps data science teams collaborate and deploy models reproducibly.
Why This Matters
- Reproducibility: Track experiments and reproduce results
- Collaboration: Share experiments and models across the team
- Deployment: Package and deploy models consistently
- Governance: Model versioning and approval workflows
Core Concepts
1. Experiment Tracking
import mlflow
from mlflow.tracking import MlflowClient
from xgboost import XGBClassifier
from sklearn.metrics import (
    accuracy_score, precision_score, recall_score, f1_score, roc_auc_score,
)

# Set tracking URI and experiment
mlflow.set_tracking_uri("http://mlflow-server:5000")
mlflow.set_experiment("customer-churn-prediction")

# Enable auto-logging for scikit-learn-compatible estimators
mlflow.sklearn.autolog()

# X_train, X_test, y_train, y_test are assumed to be prepared upstream
with mlflow.start_run(run_name="xgboost-v1") as run:
    # Log parameters
    mlflow.log_params({
        "learning_rate": 0.1,
        "max_depth": 6,
        "n_estimators": 100,
        "subsample": 0.8,
    })

    # Train model
    model = XGBClassifier(
        learning_rate=0.1,
        max_depth=6,
        n_estimators=100,
        subsample=0.8,
    )
    model.fit(X_train, y_train)

    # Log metrics
    y_pred = model.predict(X_test)
    mlflow.log_metrics({
        "accuracy": accuracy_score(y_test, y_pred),
        "precision": precision_score(y_test, y_pred),
        "recall": recall_score(y_test, y_pred),
        "f1": f1_score(y_test, y_pred),
        "auc_roc": roc_auc_score(y_test, model.predict_proba(X_test)[:, 1]),
    })

    # Log artifacts (plots saved to disk earlier in the pipeline)
    mlflow.log_artifact("feature_importance.png")
    mlflow.log_artifact("confusion_matrix.png")

    # Log and register the model
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        registered_model_name="churn-prediction-model",
    )

    # Log dataset info
    mlflow.log_input(
        mlflow.data.from_pandas(X_train, source="s3://data/train.parquet"),
        context="training",
    )

    print(f"Run ID: {run.info.run_id}")
2. Custom Model Wrapper
import mlflow.pyfunc
import pandas as pd


class ChurnModelWrapper(mlflow.pyfunc.PythonModel):
    """Custom model wrapper with preprocessing"""

    def load_context(self, context):
        """Load model and artifacts"""
        import joblib
        import json

        self.model = joblib.load(context.artifacts["model"])
        self.preprocessor = joblib.load(context.artifacts["preprocessor"])
        # context.artifacts["feature_names"] is a local path to features.json
        with open(context.artifacts["feature_names"]) as f:
            self.feature_names = json.load(f)

    def predict(self, context, model_input):
        """Predict with preprocessing"""
        # Validate input
        missing = [col for col in self.feature_names if col not in model_input.columns]
        if missing:
            raise ValueError(f"Missing required features: {missing}")

        # Preprocess
        processed = self.preprocessor.transform(model_input[self.feature_names])

        # Predict with probability
        predictions = self.model.predict_proba(processed)[:, 1]
        return pd.DataFrame({
            "churn_probability": predictions,
            "churn_prediction": (predictions > 0.5).astype(int),
        })


# Log custom model
with mlflow.start_run():
    # Local files produced earlier in the pipeline
    artifacts = {
        "model": "model.joblib",
        "preprocessor": "preprocessor.joblib",
        "feature_names": "features.json",
    }
    mlflow.pyfunc.log_model(
        artifact_path="model",
        python_model=ChurnModelWrapper(),
        artifacts=artifacts,
        conda_env={
            "dependencies": [
                "python=3.10",
                "scikit-learn=1.3.0",
                "xgboost=2.0.0",
                "pandas=2.0.0",
            ]
        },
        # `predictions` is an example output DataFrame with the same schema as
        # ChurnModelWrapper.predict returns; it is used to infer the output signature
        signature=mlflow.models.infer_signature(X_test, predictions),
        input_example=X_test.head(5),
    )
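A quick usage check of the logged wrapper (a sketch; it assumes the run above is the most recent run in this process and that X_test is available):
# Load the wrapped model back as a generic pyfunc and score a small batch
run_id = mlflow.last_active_run().info.run_id
loaded = mlflow.pyfunc.load_model(f"runs:/{run_id}/model")
print(loaded.predict(X_test.head(10)))  # columns: churn_probability, churn_prediction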
3. Model Registry
from mlflow.tracking import MlflowClient

client = MlflowClient()

# Register model from run
model_uri = f"runs:/{run_id}/model"
model_version = mlflow.register_model(model_uri, "churn-prediction-model")

# Add description and tags
client.update_model_version(
    name="churn-prediction-model",
    version=model_version.version,
    description="XGBoost model trained on Q4 2024 data",
)
client.set_model_version_tag(
    name="churn-prediction-model",
    version=model_version.version,
    key="validation_status",
    value="pending",
)

# Transition to staging (after validation)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Staging",
    archive_existing_versions=False,
)

# Promote to production (after approval)
client.transition_model_version_stage(
    name="churn-prediction-model",
    version=model_version.version,
    stage="Production",
    archive_existing_versions=True,  # Archive old production version
)

# Load production model
model = mlflow.pyfunc.load_model("models:/churn-prediction-model/Production")
predictions = model.predict(new_data)
4. Model Validation Pipeline
# validation/validate_model.py
import mlflow
import pandas as pd
from mlflow.tracking import MlflowClient
from sklearn.metrics import accuracy_score, roc_auc_score


def validate_model(model_name: str, version: str) -> bool:
    """Validate a model version before promotion"""
    client = MlflowClient()
    model_uri = f"models:/{model_name}/{version}"

    # Load candidate model
    model = mlflow.pyfunc.load_model(model_uri)

    # Load validation dataset
    val_data = pd.read_parquet("s3://data/validation.parquet")
    X_val, y_val = val_data.drop("target", axis=1), val_data["target"]

    # Run predictions
    predictions = model.predict(X_val)

    # Calculate metrics
    metrics = {
        "val_accuracy": accuracy_score(y_val, predictions["churn_prediction"]),
        "val_auc": roc_auc_score(y_val, predictions["churn_probability"]),
    }

    # Get production model metrics (if a production version exists)
    try:
        prod_model = mlflow.pyfunc.load_model(f"models:/{model_name}/Production")
        prod_predictions = prod_model.predict(X_val)
        prod_metrics = {
            "prod_accuracy": accuracy_score(y_val, prod_predictions["churn_prediction"]),
            "prod_auc": roc_auc_score(y_val, prod_predictions["churn_probability"]),
        }
    except Exception:
        prod_metrics = {"prod_accuracy": 0.0, "prod_auc": 0.0}

    # Validation rules
    validations = [
        ("accuracy_threshold", metrics["val_accuracy"] >= 0.85),
        ("auc_threshold", metrics["val_auc"] >= 0.80),
        ("accuracy_improvement", metrics["val_accuracy"] >= prod_metrics["prod_accuracy"]),
        ("auc_improvement", metrics["val_auc"] >= prod_metrics["prod_auc"] - 0.01),  # Allow 1% drop
    ]

    # Log validation results
    with mlflow.start_run(run_name=f"validation-{model_name}-v{version}"):
        mlflow.log_metrics(metrics)
        mlflow.log_metrics(prod_metrics)
        for name, passed in validations:
            mlflow.log_metric(f"validation_{name}", int(passed))

    # Update model tags
    all_passed = all(passed for _, passed in validations)
    client.set_model_version_tag(
        name=model_name,
        version=version,
        key="validation_status",
        value="passed" if all_passed else "failed",
    )
    return all_passed
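The stage transition from section 3 can then be gated on this check instead of being done by hand. A minimal sketch, assuming the registered model name used above and a hypothetical candidate version:
from mlflow.tracking import MlflowClient

MODEL_NAME = "churn-prediction-model"
CANDIDATE_VERSION = "7"  # hypothetical version produced by the latest training run

if validate_model(MODEL_NAME, CANDIDATE_VERSION):
    # Promote only when every validation rule passed
    MlflowClient().transition_model_version_stage(
        name=MODEL_NAME,
        version=CANDIDATE_VERSION,
        stage="Production",
        archive_existing_versions=True,
    )
else:
    print(f"Version {CANDIDATE_VERSION} failed validation; leaving it in Staging")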
5. Model Serving
# serve/model_server.py
import mlflow
import pandas as pd
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

app = FastAPI()

# Load model at startup
MODEL_NAME = "churn-prediction-model"
MODEL_STAGE = "Production"
model = None


@app.on_event("startup")
async def load_model():
    global model
    model = mlflow.pyfunc.load_model(f"models:/{MODEL_NAME}/{MODEL_STAGE}")


class PredictionRequest(BaseModel):
    features: dict


class PredictionResponse(BaseModel):
    churn_probability: float
    churn_prediction: int
    model_version: str


@app.post("/predict", response_model=PredictionResponse)
async def predict(request: PredictionRequest):
    try:
        input_df = pd.DataFrame([request.features])
        predictions = model.predict(input_df)
        return PredictionResponse(
            churn_probability=float(predictions["churn_probability"].iloc[0]),
            churn_prediction=int(predictions["churn_prediction"].iloc[0]),
            model_version=model.metadata.run_id,
        )
    except Exception as e:
        raise HTTPException(status_code=400, detail=str(e))


@app.get("/health")
async def health():
    return {"status": "healthy", "model_loaded": model is not None}


# Or use MLflow's built-in serving:
#   mlflow models serve -m "models:/churn-prediction-model/Production" -p 5001
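Calling the FastAPI endpoint from a client looks roughly like this (a sketch; host, port, and feature names are placeholders for your own schema):
import requests

# Hypothetical feature payload; use the columns your preprocessor expects
payload = {"features": {"tenure": 12, "monthly_charges": 70.5, "num_support_calls": 3}}

resp = requests.post("http://localhost:8000/predict", json=payload, timeout=5)
resp.raise_for_status()
print(resp.json())  # e.g. {"churn_probability": 0.27, "churn_prediction": 0, "model_version": "..."}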
Quick Start
- Install MLflow:
  pip install mlflow
- Start tracking server:
  mlflow server --backend-store-uri sqlite:///mlflow.db \
    --default-artifact-root s3://mlflow-artifacts \
    --host 0.0.0.0
- Set tracking URI in code:
  mlflow.set_tracking_uri("http://localhost:5000")
- Run experiment:
  with mlflow.start_run():
      mlflow.log_param("param", value)
      mlflow.log_metric("metric", value)
      mlflow.sklearn.log_model(model, "model")
- View in UI: open http://localhost:5000
Production Checklist
- Tracking server with persistent backend
- Artifact storage (S3/GCS/Azure Blob)
- Authentication enabled
- Model signature defined
- Input examples logged
- Conda/pip environment specified
- Validation pipeline configured
- Model approval workflow
- Monitoring for model drift
Anti-patterns
- No Experiment Naming: Use meaningful experiment/run names
- Skipping Signatures: Always define model signatures (see the sketch after this list)
- Manual Promotion: Use validation pipeline for stage transitions
- Missing Environment: Always specify dependencies
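For the signature and environment anti-patterns (and the matching checklist items above), a hedged sketch of logging a model with an explicit signature, input example, and pinned pip requirements; the version pins are illustrative, not prescriptive:
import mlflow
from mlflow.models import infer_signature

with mlflow.start_run(run_name="xgboost-v2-signed"):  # meaningful run name
    signature = infer_signature(X_train, model.predict(X_train))
    mlflow.sklearn.log_model(
        model,
        artifact_path="model",
        signature=signature,            # input/output schema checked at load and serve time
        input_example=X_train.head(5),  # stored alongside the model
        pip_requirements=["scikit-learn==1.3.0", "xgboost==2.0.0", "pandas==2.0.0"],
    )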
Integration Points
- Storage: S3, GCS, Azure Blob, HDFS
- Databases: PostgreSQL, MySQL for backend store
- Orchestration: Airflow, Prefect, Dagster (see the sketch after this list)
- Serving: SageMaker, Kubernetes, Azure ML
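For the orchestration integrations, retraining and validation typically run as scheduled tasks. A rough Airflow 2.x sketch (the train step is a placeholder, and validate_model is the function from section 4):
from datetime import datetime

from airflow.decorators import dag, task


@dag(schedule="@weekly", start_date=datetime(2024, 1, 1), catchup=False)
def churn_model_pipeline():
    @task
    def train() -> str:
        # Placeholder: run the training code from section 1 here, register the
        # model, and return the new model version number
        return "7"  # hypothetical version

    @task
    def validate(version: str) -> None:
        from validation.validate_model import validate_model  # section 4 module

        if not validate_model("churn-prediction-model", version):
            raise ValueError(f"Version {version} failed validation")

    validate(train())


churn_model_pipeline()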