GadaaLabs
Python Mastery — From Zero to AI Engineering
Lesson 17

Production Python — FastAPI, Packaging & Profiling

30 min

Project Structure

Two dominant layouts for Python projects:

Flat layout — simple, good for applications:

my_project/
├── app/
│   ├── __init__.py
│   ├── main.py
│   ├── models.py
│   └── api/
├── tests/
├── pyproject.toml
└── README.md

src layout — recommended for libraries (prevents accidental imports from the project root):

my_library/
├── src/
│   └── my_library/
│       ├── __init__.py
│       ├── core.py
│       └── utils.py
├── tests/
├── pyproject.toml
└── README.md

Use src layout when publishing to PyPI. Use flat layout for web applications and scripts. The key difference: with src layout, you can only import your package if it's installed, which catches packaging bugs early.

pyproject.toml — The Modern Standard

pyproject.toml replaces setup.py, setup.cfg, requirements.txt, and MANIFEST.in:

toml
[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[project]
name = "ml-serving"
version = "0.1.0"
description = "FastAPI-based ML model serving"
readme = "README.md"
requires-python = ">=3.11"
license = {text = "MIT"}
dependencies = [
    "fastapi>=0.111.0",
    "uvicorn[standard]>=0.29.0",
    "pydantic>=2.7.0",
    "scikit-learn>=1.4.0",
    "numpy>=1.26.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=8.0.0",
    "pytest-asyncio>=0.23.0",  # Required for asyncio_mode = "auto" below
    "httpx>=0.27.0",    # For FastAPI TestClient
    "ruff>=0.4.0",
    "mypy>=1.10.0",
]

[project.scripts]
serve = "ml_serving.main:start"

[tool.ruff]
line-length = 88

[tool.ruff.lint]
select = ["E", "F", "I", "N", "UP"]

[tool.mypy]
strict = true
python_version = "3.11"

[tool.pytest.ini_options]
testpaths = ["tests"]
asyncio_mode = "auto"

Install in development mode: pip install -e ".[dev]". The -e flag installs a symlink to your source, so edits take effect immediately without reinstalling.

FastAPI — Async ML Serving

FastAPI generates OpenAPI docs automatically, validates requests via Pydantic, and handles async operations natively:

python
# app/main.py
from contextlib import asynccontextmanager
from fastapi import FastAPI, HTTPException, Depends
from pydantic import BaseModel, Field, field_validator
import numpy as np
import pickle
import logging

logger = logging.getLogger(__name__)

# ── Lifespan: load model on startup, clean up on shutdown ─────────────────────
@asynccontextmanager
async def lifespan(app: FastAPI):
    # Startup
    logger.info("Loading model...")
    with open("model.pkl", "rb") as f:
        app.state.model = pickle.load(f)
    with open("scaler.pkl", "rb") as f:
        app.state.scaler = pickle.load(f)
    logger.info("Model loaded successfully")
    yield
    # Shutdown
    logger.info("Shutting down")

app = FastAPI(
    title="ML Model API",
    version="1.0.0",
    lifespan=lifespan,
)

# ── Pydantic models define the API contract ────────────────────────────────────
class PredictRequest(BaseModel):
    tenure: int = Field(..., ge=1, le=72, description="Customer tenure in months")
    monthly_charges: float = Field(..., gt=0, le=500)
    contract_type: str = Field(..., pattern="^(month-to-month|one_year|two_year)$")
    num_products: int = Field(..., ge=1, le=10)

    @field_validator("monthly_charges")
    @classmethod
    def validate_charges(cls, v):
        if v < 20:
            raise ValueError("Monthly charges seem too low")
        return round(v, 2)

class PredictResponse(BaseModel):
    # "model_" is a protected attribute prefix in Pydantic v2; opt out so
    # the model_version field doesn't trigger a namespace warning
    model_config = {"protected_namespaces": ()}

    churn_probability: float
    churn_prediction: bool
    confidence: str
    model_version: str = "1.0.0"

class HealthResponse(BaseModel):
    status: str
    model_loaded: bool

# ── Endpoints ─────────────────────────────────────────────────────────────────
@app.get("/health", response_model=HealthResponse)
async def health():
    return {
        "status": "healthy",
        "model_loaded": hasattr(app.state, "model"),
    }

@app.post("/predict", response_model=PredictResponse)
async def predict(request: PredictRequest):
    try:
        features = np.array([[
            request.tenure,
            request.monthly_charges,
            1 if request.contract_type == "month-to-month" else 0,
            request.num_products,
        ]])
        features_scaled = app.state.scaler.transform(features)
        proba = app.state.model.predict_proba(features_scaled)[0, 1]
        distance = abs(proba - 0.5)
        if distance > 0.3:
            confidence = "high"
        elif distance > 0.1:
            confidence = "medium"
        else:
            confidence = "low"

        logger.info(f"Prediction: proba={proba:.3f}, tenure={request.tenure}")
        return PredictResponse(
            churn_probability=round(float(proba), 4),
            churn_prediction=proba > 0.5,
            confidence=confidence,
        )
    except Exception as e:
        logger.error(f"Prediction failed: {e}")
        raise HTTPException(status_code=500, detail=f"Prediction failed: {e}") from e

@app.get("/model/info")
async def model_info():
    return {
        "model_type": type(app.state.model).__name__,
        "features": ["tenure", "monthly_charges", "contract_type", "num_products"],
        "output": "churn_probability [0, 1]",
    }

Pydantic v2

Pydantic v2 (released 2023) is 5-50x faster than v1 due to a Rust core:

python
from pydantic import BaseModel, Field, field_validator, model_validator
from typing import Optional
from datetime import datetime, timezone

class MLExperiment(BaseModel):
    model_config = {"str_strip_whitespace": True, "frozen": True}

    name: str = Field(..., min_length=3, max_length=100)
    accuracy: float = Field(..., ge=0.0, le=1.0)
    parameters: dict[str, float | int | str]
    tags: list[str] = []
    # datetime.utcnow is deprecated as of 3.12; use an aware UTC timestamp
    created_at: datetime = Field(default_factory=lambda: datetime.now(timezone.utc))
    parent_run_id: Optional[str] = None

    @field_validator("name")
    @classmethod
    def no_spaces_in_name(cls, v: str) -> str:
        return v.replace(" ", "_").lower()

    @model_validator(mode="after")
    def check_high_accuracy_has_tag(self) -> "MLExperiment":
        if self.accuracy > 0.95 and "validated" not in self.tags:
            raise ValueError("High-accuracy models must include 'validated' tag")
        return self

# Serialization
exp = MLExperiment(name="churn model v2", accuracy=0.88, parameters={"n_estimators": 100})
print(exp.model_dump())
print(exp.model_dump_json(indent=2))

# Parsing from dict/JSON
data = {"name": "test", "accuracy": 0.75, "parameters": {"lr": 0.01}}
exp2 = MLExperiment.model_validate(data)
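
To see the cross-field check above fire, here is a self-contained sketch with the model pared down to just the two fields involved (assumes Pydantic v2 is installed):

```python
from pydantic import BaseModel, Field, ValidationError, model_validator

class Experiment(BaseModel):
    accuracy: float = Field(..., ge=0.0, le=1.0)
    tags: list[str] = []

    @model_validator(mode="after")
    def check_high_accuracy_has_tag(self) -> "Experiment":
        if self.accuracy > 0.95 and "validated" not in self.tags:
            raise ValueError("High-accuracy models must include 'validated' tag")
        return self

# Missing the required tag: the model_validator rejects construction
try:
    Experiment(accuracy=0.97)
except ValidationError as exc:
    print(exc.error_count(), "validation error")   # 1 validation error

# With the tag, construction succeeds
Experiment(accuracy=0.97, tags=["validated"])
print("tagged experiment accepted")
```

Note that `model_validator(mode="after")` runs only once all field-level validation has passed, so it can safely read `self.accuracy` and `self.tags`.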

Structured Logging

Production logging should be machine-parseable. Never use print() in production:

python
import logging
import json
import sys
import uuid
from contextvars import ContextVar

# Correlation ID for tracing requests across logs
correlation_id: ContextVar[str] = ContextVar("correlation_id", default="")

class JSONFormatter(logging.Formatter):
    def format(self, record: logging.LogRecord) -> str:
        log_obj = {
            "timestamp":      self.formatTime(record),
            "level":          record.levelname,
            "logger":         record.name,
            "message":        record.getMessage(),
            "module":         record.module,
            "correlation_id": correlation_id.get(""),
        }
        if record.exc_info:
            log_obj["exception"] = self.formatException(record.exc_info)
        return json.dumps(log_obj)

def setup_logging(level: str = "INFO") -> None:
    handler = logging.StreamHandler(sys.stdout)
    handler.setFormatter(JSONFormatter())
    root = logging.getLogger()
    root.setLevel(getattr(logging, level.upper()))
    root.addHandler(handler)

# FastAPI middleware to inject correlation IDs
from fastapi import Request
from starlette.middleware.base import BaseHTTPMiddleware

class CorrelationIDMiddleware(BaseHTTPMiddleware):
    async def dispatch(self, request: Request, call_next):
        cid = request.headers.get("X-Correlation-ID", str(uuid.uuid4())[:8])
        token = correlation_id.set(cid)
        try:
            response = await call_next(request)
            response.headers["X-Correlation-ID"] = cid
            return response
        finally:
            correlation_id.reset(token)
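
A minimal end-to-end check of the JSON-formatter idea, stdlib only and cut down to two fields so it runs standalone:

```python
import io
import json
import logging

class MiniJSONFormatter(logging.Formatter):
    """Two-field cut-down of the JSONFormatter pattern above."""
    def format(self, record: logging.LogRecord) -> str:
        return json.dumps({"level": record.levelname, "message": record.getMessage()})

# Capture output in a StringIO so we can inspect the emitted line
stream = io.StringIO()
handler = logging.StreamHandler(stream)
handler.setFormatter(MiniJSONFormatter())
log = logging.getLogger("demo")
log.addHandler(handler)
log.setLevel(logging.INFO)

log.info("model loaded")
print(json.loads(stream.getvalue()))   # {'level': 'INFO', 'message': 'model loaded'}
```

Because each line is valid JSON, log aggregators can filter on `level`, `correlation_id`, or any other field without fragile regex parsing.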

Profiling

Two interactive notebooks accompany this section in the browser version of the lesson: "Profiling: Memory and CPU" and "Performance Profiler — 5 Implementations".
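
If you are following along outside the browser, the same ground is covered by the stdlib: cProfile for CPU time and tracemalloc for memory. A minimal sketch (`sum_of_squares` is just a stand-in workload):

```python
import cProfile
import io
import pstats
import tracemalloc

def sum_of_squares(n: int) -> int:
    return sum(i * i for i in range(n))

# CPU: profile the call, then report the three most expensive functions
profiler = cProfile.Profile()
profiler.enable()
sum_of_squares(100_000)
profiler.disable()
report = io.StringIO()
pstats.Stats(profiler, stream=report).sort_stats("cumulative").print_stats(3)
print(next(line for line in report.getvalue().splitlines() if line.strip()))

# Memory: measure peak allocations around an allocation-heavy expression
tracemalloc.start()
squares = [i * i for i in range(100_000)]
current, peak = tracemalloc.get_traced_memory()
tracemalloc.stop()
print(f"peak traced memory: {peak / 1_048_576:.1f} MiB")
```

Both tools carry overhead, so profile representative workloads rather than production traffic, and always measure before optimizing.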

Concurrency in FastAPI

python
from fastapi import FastAPI, BackgroundTasks
import asyncio

from app.main import PredictRequest   # Request model from the earlier example

app = FastAPI()

# Placeholder helpers; real implementations would wrap the model call
async def run_prediction_in_thread(request: PredictRequest) -> dict: ...
def blocking_predict(request: PredictRequest) -> dict: ...

# Async endpoints: release the event loop during I/O waits
@app.post("/predict/async")
async def predict_async(request: PredictRequest):
    # Non-blocking: the event loop handles other requests while waiting
    result = await run_prediction_in_thread(request)
    return result

# CPU-bound: offload to a thread pool to avoid blocking the event loop
from fastapi.concurrency import run_in_threadpool

@app.post("/predict/cpu")
async def predict_cpu(request: PredictRequest):
    # run_in_threadpool runs blocking code in a thread pool
    result = await run_in_threadpool(blocking_predict, request)
    return result

# Background tasks: return immediately, work continues after response
@app.post("/train")
async def trigger_training(background_tasks: BackgroundTasks):
    background_tasks.add_task(retrain_model_job, dataset_path="data/new.csv")
    return {"status": "training started"}

async def retrain_model_job(dataset_path: str):
    """Runs after the response is sent; BackgroundTasks handles the scheduling."""
    # ... long training job ...
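
Outside FastAPI, the same offload pattern exists in the stdlib as asyncio.to_thread, which run_in_threadpool essentially mirrors. A self-contained sketch (`blocking_predict` here is a stand-in for real model work):

```python
import asyncio
import time

def blocking_predict(x: int) -> int:
    time.sleep(0.01)          # Stand-in for CPU-heavy or blocking model work
    return x * 2

async def main() -> list[int]:
    # Run three blocking calls concurrently on the default thread pool;
    # the event loop stays free while they execute
    return await asyncio.gather(
        *(asyncio.to_thread(blocking_predict, i) for i in range(3))
    )

print(asyncio.run(main()))   # [0, 2, 4]
```

Because the three sleeps overlap, the whole batch completes in roughly one sleep interval rather than three.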

Testing a FastAPI App

python
# tests/test_api.py
import numpy as np
import pytest
from fastapi.testclient import TestClient
from unittest.mock import MagicMock, patch
from app.main import app

@pytest.fixture
def client():
    # Entering the context runs the lifespan, which expects model.pkl
    # and scaler.pkl on disk
    with TestClient(app) as c:
        yield c

@pytest.fixture
def mock_model():
    model = MagicMock()
    # Must be an ndarray: the endpoint indexes the result with [0, 1]
    model.predict_proba.return_value = np.array([[0.3, 0.7]])   # 70% churn probability
    return model

def test_health(client):
    response = client.get("/health")
    assert response.status_code == 200
    assert response.json()["status"] == "healthy"

def test_predict_valid(client, mock_model):
    with patch.object(app.state, "model", mock_model):
        response = client.post("/predict", json={
            "tenure": 12,
            "monthly_charges": 65.0,
            "contract_type": "month-to-month",
            "num_products": 2,
        })
    assert response.status_code == 200
    data = response.json()
    assert 0 <= data["churn_probability"] <= 1
    assert data["churn_prediction"] is True   # 70% > 50%

def test_predict_invalid_contract(client):
    response = client.post("/predict", json={
        "tenure": 12,
        "monthly_charges": 65.0,
        "contract_type": "invalid",   # Should fail validation
        "num_products": 2,
    })
    assert response.status_code == 422   # Pydantic validation error

Docker for Python

dockerfile
# Dockerfile
FROM python:3.11-slim AS base
WORKDIR /app

# Dependencies layer (cached separately from code). Hatchling needs the
# package source present to build, so create a stub package for the
# editable install; the real code arrives in the next layer.
COPY pyproject.toml .
RUN mkdir -p src/ml_serving && touch src/ml_serving/__init__.py \
    && pip install --no-cache-dir -e .

# Code layer
COPY src/ src/
COPY app/ app/

# Non-root user for security
RUN useradd --create-home appuser
USER appuser

EXPOSE 8000
CMD ["uvicorn", "app.main:app", "--host", "0.0.0.0", "--port", "8000", "--workers", "4"]
dockerfile
# Multi-stage build to reduce image size
FROM python:3.11-slim AS builder
WORKDIR /build
COPY pyproject.toml .
COPY src/ src/
RUN pip install --no-cache-dir build && python -m build --wheel

FROM python:3.11-slim AS runtime
WORKDIR /app
COPY --from=builder /build/dist/*.whl .
RUN pip install --no-cache-dir *.whl && rm *.whl
CMD ["uvicorn", "ml_serving.main:app", "--host", "0.0.0.0", "--port", "8000"]

Layer caching matters: put slowly-changing content (dependencies) before rapidly-changing content (code). The COPY pyproject.toml + RUN pip install layer is cached unless pyproject.toml changes.

Key Takeaways

  • Use src/ layout for libraries, flat layout for applications — this prevents silent import bugs in packaging
  • pyproject.toml is the single source of truth for project metadata, dependencies, and tool configuration
  • FastAPI's lifespan context manager is the correct place to load ML models — not module-level globals
  • Pydantic v2 validators (@field_validator, @model_validator) are the first line of defense against bad input
  • Never use print() in production — structured JSON logs are queryable by log aggregation systems
  • Profile with cProfile for CPU, tracemalloc for memory — never optimize without measuring first
  • NumPy vectorization can be 100x faster than Python loops; use np.dot() for sum-of-products patterns
  • Docker layer order is a performance decision: COPY requirements before COPY source maximizes cache hits
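
The vectorization takeaway can be checked in a few lines; the "100x" figure is illustrative and the actual speedup varies with array size and hardware:

```python
import numpy as np

rng = np.random.default_rng(0)
weights = rng.random(1_000)
features = rng.random(1_000)

# Same sum-of-products, two ways: a Python-level loop vs. one BLAS call
loop_result = sum(w * f for w, f in zip(weights, features))
vec_result = float(np.dot(weights, features))

assert np.isclose(loop_result, vec_result)
print("loop and np.dot agree")
```

Timing the two with `timeit` on larger arrays makes the gap obvious: the loop crosses the Python/C boundary once per element, while `np.dot` does it once per call.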