
Model Deployment

Model deployment is the final step of a machine learning project, and one of the most critical. This chapter covers how to deploy TensorFlow models to production, including the main deployment options and best practices.

Deployment Overview

Comparison of Deployment Options

| Deployment option  | Typical use case         | Pros                          | Cons                  |
|--------------------|--------------------------|-------------------------------|-----------------------|
| TensorFlow Serving | High-performance serving | High throughput, low latency  | Complex configuration |
| Flask/FastAPI      | Rapid prototyping        | Simple, easy to use, flexible | Limited performance   |
| TensorFlow Lite    | Mobile / edge devices    | Small models, fast inference  | Limited functionality |
| TensorFlow.js      | Browser / Node.js        | Client-side inference         | Model size limits     |
| Docker containers  | Cloud deployment         | Environment consistency       | Resource overhead     |
| Kubernetes         | Large-scale deployment   | Automatic scaling             | High complexity       |
python
import tensorflow as tf
from tensorflow import keras
import numpy as np
import json
import os
from pathlib import Path
import requests
import time

print(f"TensorFlow版本: {tf.__version__}")

TensorFlow Serving

Preparing and Exporting the Model

python
def create_sample_model():
    """
    创建示例模型用于部署
    """
    model = keras.Sequential([
        keras.layers.Dense(128, activation='relu', input_shape=(784,)),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(64, activation='relu'),
        keras.layers.Dropout(0.2),
        keras.layers.Dense(10, activation='softmax')
    ])
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    return model

def export_model_for_serving(model, export_path, version=1):
    """
    导出模型用于TensorFlow Serving
    """
    # 创建版本目录
    versioned_path = os.path.join(export_path, str(version))
    
    # 保存模型
    tf.saved_model.save(model, versioned_path)
    
    print(f"模型已导出到: {versioned_path}")
    
    # 验证导出的模型
    loaded_model = tf.saved_model.load(versioned_path)
    print("模型签名:")
    print(list(loaded_model.signatures.keys()))
    
    return versioned_path

def create_model_with_preprocessing():
    """
    创建包含预处理的模型
    """
    # 输入层
    inputs = keras.layers.Input(shape=(28, 28), name='image')
    
    # 预处理层
    x = keras.layers.Reshape((784,))(inputs)
    x = keras.layers.Lambda(lambda x: tf.cast(x, tf.float32) / 255.0)(x)
    
    # 主模型
    x = keras.layers.Dense(128, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    x = keras.layers.Dense(64, activation='relu')(x)
    x = keras.layers.Dropout(0.2)(x)
    outputs = keras.layers.Dense(10, activation='softmax', name='predictions')(x)
    
    model = keras.Model(inputs=inputs, outputs=outputs)
    
    model.compile(
        optimizer='adam',
        loss='sparse_categorical_crossentropy',
        metrics=['accuracy']
    )
    
    return model

# Create and export the models
sample_model = create_sample_model()
model_with_preprocessing = create_model_with_preprocessing()

# Export the model
export_path = './models/mnist_classifier'
export_model_for_serving(model_with_preprocessing, export_path, version=1)
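
By default, tf.saved_model.save exports the Keras call signature. If you want the request payload to be accepted under an explicitly named input, you can attach a custom serving signature before exporting. This is a minimal sketch; the export path './models/mnist_classifier_signed/1' and the output key 'predictions' are illustrative choices, not part of the original export code:

python
@tf.function(input_signature=[tf.TensorSpec(shape=[None, 28, 28], dtype=tf.float32, name='image')])
def serve_fn(image):
    """Custom serving entry point: run the preprocessing model in inference mode."""
    return {'predictions': model_with_preprocessing(image, training=False)}

# Hypothetical variant of export_model_for_serving with an explicit signature
tf.saved_model.save(
    model_with_preprocessing,
    './models/mnist_classifier_signed/1',
    signatures={'serving_default': serve_fn}
)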

TensorFlow Serving Configuration

python
def create_serving_config(model_name, model_base_path, model_platform='tensorflow'):
    """
    创建TensorFlow Serving配置文件
    """
    config = {
        "model_config_list": [
            {
                "name": model_name,
                "base_path": model_base_path,
                "model_platform": model_platform,
                "model_version_policy": {
                    "latest": {
                        "num_versions": 2
                    }
                }
            }
        ]
    }
    
    config_path = f"{model_name}_config.json"
    with open(config_path, 'w') as f:
        json.dump(config, f, indent=2)
    
    print(f"配置文件已保存到: {config_path}")
    return config_path

def create_docker_compose_serving():
    """
    创建Docker Compose文件用于TensorFlow Serving
    """
    docker_compose_content = """
version: '3.8'

services:
  tensorflow-serving:
    image: tensorflow/serving:latest
    ports:
      - "8501:8501"  # REST API
      - "8500:8500"  # gRPC API
    volumes:
      - ./models:/models
    environment:
      - MODEL_NAME=mnist_classifier
      - MODEL_BASE_PATH=/models/mnist_classifier
    command: >
      tensorflow_model_server
      --rest_api_port=8501
      --model_name=mnist_classifier
      --model_base_path=/models/mnist_classifier
      --monitoring_config_file=""
"""
    
    with open('docker-compose-serving.yml', 'w') as f:
        f.write(docker_compose_content)
    
    print("Docker Compose文件已创建: docker-compose-serving.yml")

# Create the configuration files
create_serving_config('mnist_classifier', '/models/mnist_classifier')
create_docker_compose_serving()
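
Note that tensorflow_model_server reads the file passed to --model_config_file as a protobuf text file, not JSON; the JSON written above is only a readable illustration of the same structure. A minimal sketch of writing the text-proto form (the file name models.config is an arbitrary choice):

python
def create_serving_config_textproto(model_name, model_base_path, config_path='models.config'):
    """Write a model_config_list in the protobuf text format expected by TensorFlow Serving."""
    config_text = f"""
model_config_list {{
  config {{
    name: "{model_name}"
    base_path: "{model_base_path}"
    model_platform: "tensorflow"
    model_version_policy {{
      latest {{
        num_versions: 2
      }}
    }}
  }}
}}
"""
    with open(config_path, 'w') as f:
        f.write(config_text.strip() + "\n")
    print(f"Serving config written to: {config_path}")

# create_serving_config_textproto('mnist_classifier', '/models/mnist_classifier')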

Client Calls

python
class TensorFlowServingClient:
    """
    TensorFlow Serving客户端
    """
    def __init__(self, server_url='http://localhost:8501', model_name='mnist_classifier'):
        self.server_url = server_url
        self.model_name = model_name
        self.predict_url = f"{server_url}/v1/models/{model_name}:predict"
        self.metadata_url = f"{server_url}/v1/models/{model_name}/metadata"
    
    def get_model_metadata(self):
        """
        获取模型元数据
        """
        try:
            response = requests.get(self.metadata_url)
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"获取模型元数据失败: {e}")
            return None
    
    def predict(self, instances):
        """
        进行预测
        """
        data = {
            "instances": instances
        }
        
        try:
            response = requests.post(
                self.predict_url,
                json=data,
                headers={'Content-Type': 'application/json'}
            )
            response.raise_for_status()
            return response.json()
        except requests.exceptions.RequestException as e:
            print(f"预测请求失败: {e}")
            return None
    
    def predict_batch(self, instances, batch_size=32):
        """
        批量预测
        """
        results = []
        
        for i in range(0, len(instances), batch_size):
            batch = instances[i:i + batch_size]
            result = self.predict(batch)
            
            if result and 'predictions' in result:
                results.extend(result['predictions'])
            else:
                print(f"批次 {i//batch_size + 1} 预测失败")
        
        return results
    
    def benchmark(self, instances, num_requests=100):
        """
        性能基准测试
        """
        print(f"开始性能测试,发送 {num_requests} 个请求...")
        
        start_time = time.time()
        successful_requests = 0
        
        for i in range(num_requests):
            result = self.predict(instances)
            if result:
                successful_requests += 1
            
            if (i + 1) % 10 == 0:
                print(f"已完成 {i + 1}/{num_requests} 个请求")
        
        end_time = time.time()
        total_time = end_time - start_time
        
        print(f"\n性能测试结果:")
        print(f"总请求数: {num_requests}")
        print(f"成功请求数: {successful_requests}")
        print(f"总耗时: {total_time:.2f} 秒")
        print(f"平均延迟: {total_time/num_requests*1000:.2f} ms")
        print(f"QPS: {successful_requests/total_time:.2f}")

# 使用示例
def test_serving_client():
    """
    测试TensorFlow Serving客户端
    """
    client = TensorFlowServingClient()
    
    # Create test data
    test_data = np.random.randint(0, 255, (5, 28, 28)).tolist()
    
    # Fetch model metadata
    metadata = client.get_model_metadata()
    if metadata:
        print("Model metadata:")
        print(json.dumps(metadata, indent=2))
    
    # Run a prediction
    result = client.predict(test_data)
    if result:
        print("\nPrediction result:")
        print(json.dumps(result, indent=2))
    
    # Benchmark
    single_instance = [test_data[0]]
    client.benchmark(single_instance, num_requests=50)

# Note: start the TensorFlow Serving service before calling this
# test_serving_client()
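
The client above uses the REST API on port 8501. For lower latency, the gRPC endpoint on port 8500 can be used instead. The sketch below assumes the grpcio and tensorflow-serving-api packages are installed; the tensor names 'image' and 'predictions' follow the layer names used in the model above, but should be checked against the exported signature (for example with saved_model_cli):

python
import grpc
from tensorflow_serving.apis import predict_pb2, prediction_service_pb2_grpc

def grpc_predict(instances, host='localhost:8500', model_name='mnist_classifier'):
    """Send a prediction request to TensorFlow Serving over gRPC."""
    channel = grpc.insecure_channel(host)
    stub = prediction_service_pb2_grpc.PredictionServiceStub(channel)

    request = predict_pb2.PredictRequest()
    request.model_spec.name = model_name
    request.model_spec.signature_name = 'serving_default'
    request.inputs['image'].CopyFrom(
        tf.make_tensor_proto(np.array(instances, dtype=np.float32))
    )

    response = stub.Predict(request, timeout=10.0)
    return tf.make_ndarray(response.outputs['predictions'])

# grpc_predict(np.random.rand(1, 28, 28))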

Flask/FastAPI Deployment

Flask Deployment

python
def create_flask_app(model_path):
    """
    创建Flask应用
    """
    from flask import Flask, request, jsonify
    import pickle
    
    app = Flask(__name__)
    
    # 加载模型
    model = keras.models.load_model(model_path)
    
    @app.route('/health', methods=['GET'])
    def health():
        return jsonify({'status': 'healthy', 'model_loaded': True})
    
    @app.route('/predict', methods=['POST'])
    def predict():
        try:
            data = request.get_json()
            
            if 'instances' not in data:
                return jsonify({'error': 'Missing instances field'}), 400
            
            instances = np.array(data['instances'])
            
            # Run the prediction
            predictions = model.predict(instances)
            
            # Convert to a list for JSON serialization
            predictions_list = predictions.tolist()
            
            return jsonify({
                'predictions': predictions_list,
                'model_version': '1.0'
            })
            
        except Exception as e:
            return jsonify({'error': str(e)}), 500
    
    @app.route('/predict_class', methods=['POST'])
    def predict_class():
        try:
            data = request.get_json()
            instances = np.array(data['instances'])
            
            # Run the prediction
            predictions = model.predict(instances)
            predicted_classes = np.argmax(predictions, axis=1)
            confidences = np.max(predictions, axis=1)
            
            results = []
            for i in range(len(predicted_classes)):
                results.append({
                    'predicted_class': int(predicted_classes[i]),
                    'confidence': float(confidences[i]),
                    'all_probabilities': predictions[i].tolist()
                })
            
            return jsonify({'results': results})
            
        except Exception as e:
            return jsonify({'error': str(e)}), 500
    
    return app

def create_flask_with_monitoring():
    """
    创建带监控的Flask应用
    """
    from flask import Flask, request, jsonify
    import time
    import logging
    from collections import defaultdict
    
    app = Flask(__name__)
    
    # 设置日志
    logging.basicConfig(level=logging.INFO)
    logger = logging.getLogger(__name__)
    
    # 监控指标
    metrics = {
        'request_count': defaultdict(int),
        'response_times': [],
        'error_count': 0
    }
    
    @app.before_request
    def before_request():
        request.start_time = time.time()
    
    @app.after_request
    def after_request(response):
        # Record the response time
        if hasattr(request, 'start_time'):
            response_time = time.time() - request.start_time
            metrics['response_times'].append(response_time)
            
            # Keep only the most recent 1000 response times
            if len(metrics['response_times']) > 1000:
                metrics['response_times'] = metrics['response_times'][-1000:]
        
        # Count requests per endpoint
        metrics['request_count'][request.endpoint] += 1
        
        return response
    
    @app.route('/metrics', methods=['GET'])
    def get_metrics():
        avg_response_time = np.mean(metrics['response_times']) if metrics['response_times'] else 0
        
        return jsonify({
            'request_count': dict(metrics['request_count']),
            'average_response_time': avg_response_time,
            'error_count': metrics['error_count'],
            'total_requests': sum(metrics['request_count'].values())
        })
    
    return app

# Create the Flask application
# flask_app = create_flask_app('./models/mnist_classifier/1')
# flask_app.run(host='0.0.0.0', port=5000, debug=False)
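
Once the Flask app is running, the endpoints can be exercised with a small requests-based smoke test. A minimal sketch, assuming the server listens on localhost:5000 and the model expects flattened 784-dimensional inputs:

python
def test_flask_api(base_url='http://localhost:5000'):
    """Simple smoke test for the Flask prediction service."""
    # Health check
    health = requests.get(f"{base_url}/health", timeout=5)
    print("Health:", health.json())

    # Prediction request with two random instances
    payload = {'instances': np.random.random((2, 784)).tolist()}
    response = requests.post(f"{base_url}/predict", json=payload, timeout=10)
    response.raise_for_status()
    print("Predictions:", response.json()['predictions'])

# test_flask_api()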

FastAPI Deployment

python
def create_fastapi_app():
    """
    创建FastAPI应用
    """
    from fastapi import FastAPI, HTTPException
    from pydantic import BaseModel
    from typing import List
    import uvicorn
    
    app = FastAPI(title="ML Model API", version="1.0.0")
    
    # Request/response schemas
    class PredictionRequest(BaseModel):
        instances: List[List[float]]
    
    class PredictionResponse(BaseModel):
        predictions: List[List[float]]
        model_version: str
    
    class ClassificationResponse(BaseModel):
        predicted_class: int
        confidence: float
        all_probabilities: List[float]
    
    # Model handle (in a real application, load it at startup)
    model = None
    
    @app.on_event("startup")
    async def startup_event():
        nonlocal model
        # model = keras.models.load_model('./models/mnist_classifier/1')
        print("Model loading complete")
    
    @app.get("/health")
    async def health():
        return {"status": "healthy", "model_loaded": model is not None}
    
    @app.post("/predict", response_model=PredictionResponse)
    async def predict(request: PredictionRequest):
        try:
            if model is None:
                raise HTTPException(status_code=503, detail="Model not loaded")
            
            instances = np.array(request.instances)
            predictions = model.predict(instances)
            
            return PredictionResponse(
                predictions=predictions.tolist(),
                model_version="1.0"
            )
            
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    
    @app.post("/predict_class")
    async def predict_class(request: PredictionRequest):
        try:
            if model is None:
                raise HTTPException(status_code=503, detail="Model not loaded")
            
            instances = np.array(request.instances)
            predictions = model.predict(instances)
            
            results = []
            for pred in predictions:
                predicted_class = int(np.argmax(pred))
                confidence = float(np.max(pred))
                
                results.append(ClassificationResponse(
                    predicted_class=predicted_class,
                    confidence=confidence,
                    all_probabilities=pred.tolist()
                ))
            
            return {"results": results}
            
        except HTTPException:
            raise
        except Exception as e:
            raise HTTPException(status_code=500, detail=str(e))
    
    return app

# Create the FastAPI application
# fastapi_app = create_fastapi_app()
# uvicorn.run(fastapi_app, host="0.0.0.0", port=8000)
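
FastAPI applications can be exercised without starting a server by using the bundled TestClient, which is convenient for CI. A minimal sketch, assuming the fastapi and httpx packages are installed:

python
from fastapi.testclient import TestClient

def test_fastapi_app():
    """In-process smoke test for the FastAPI application."""
    app = create_fastapi_app()

    # Using the context manager runs the startup event
    with TestClient(app) as client:
        print(client.get("/health").json())

        # Returns 503 here because the model load in startup_event is commented out
        payload = {"instances": np.random.random((1, 784)).tolist()}
        response = client.post("/predict", json=payload)
        print(response.status_code, response.json())

# test_fastapi_app()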

TensorFlow Lite Deployment

Model Conversion

python
def convert_to_tflite(model, optimization=True, quantization=False):
    """
    将模型转换为TensorFlow Lite格式
    """
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    
    # 优化设置
    if optimization:
        converter.optimizations = [tf.lite.Optimize.DEFAULT]
    
    # 量化设置
    if quantization:
        converter.representative_dataset = representative_dataset_gen
        converter.target_spec.supported_ops = [tf.lite.OpsSet.TFLITE_BUILTINS_INT8]
        converter.inference_input_type = tf.int8
        converter.inference_output_type = tf.int8
    
    # 转换模型
    tflite_model = converter.convert()
    
    return tflite_model

def representative_dataset_gen():
    """
    代表性数据集生成器(用于量化)
    """
    # 生成代表性数据
    for _ in range(100):
        data = np.random.random((1, 28, 28)).astype(np.float32)
        yield [data]

def save_tflite_model(tflite_model, model_path):
    """
    保存TensorFlow Lite模型
    """
    with open(model_path, 'wb') as f:
        f.write(tflite_model)
    
    print(f"TensorFlow Lite模型已保存到: {model_path}")

def analyze_tflite_model(model_path):
    """
    分析TensorFlow Lite模型
    """
    interpreter = tf.lite.Interpreter(model_path=model_path)
    interpreter.allocate_tensors()
    
    # 获取输入和输出详情
    input_details = interpreter.get_input_details()
    output_details = interpreter.get_output_details()
    
    print("模型分析:")
    print(f"输入详情: {input_details}")
    print(f"输出详情: {output_details}")
    
    # 获取模型大小
    model_size = os.path.getsize(model_path)
    print(f"模型大小: {model_size / 1024:.2f} KB")
    
    return interpreter, input_details, output_details

# Convert the model
sample_model = create_sample_model()

# Standard conversion
tflite_model = convert_to_tflite(sample_model, optimization=True)
save_tflite_model(tflite_model, 'model.tflite')

# Quantized conversion
tflite_quantized_model = convert_to_tflite(sample_model, optimization=True, quantization=True)
save_tflite_model(tflite_quantized_model, 'model_quantized.tflite')

# Analyze the model
interpreter, input_details, output_details = analyze_tflite_model('model.tflite')
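
Besides full integer quantization, float16 quantization is a lighter-weight option: it roughly halves the model size while keeping float inputs and outputs, which is often sufficient for mobile GPU delegates. A minimal sketch of this variant:

python
def convert_to_tflite_float16(model):
    """Convert a Keras model to TensorFlow Lite with float16 weight quantization."""
    converter = tf.lite.TFLiteConverter.from_keras_model(model)
    converter.optimizations = [tf.lite.Optimize.DEFAULT]
    converter.target_spec.supported_types = [tf.float16]
    return converter.convert()

# tflite_fp16_model = convert_to_tflite_float16(sample_model)
# save_tflite_model(tflite_fp16_model, 'model_fp16.tflite')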

TensorFlow Lite Inference

python
class TFLitePredictor:
    """
    TensorFlow Lite预测器
    """
    def __init__(self, model_path):
        self.interpreter = tf.lite.Interpreter(model_path=model_path)
        self.interpreter.allocate_tensors()
        
        self.input_details = self.interpreter.get_input_details()
        self.output_details = self.interpreter.get_output_details()
        
        print(f"模型加载完成: {model_path}")
        print(f"输入形状: {self.input_details[0]['shape']}")
        print(f"输出形状: {self.output_details[0]['shape']}")
    
    def predict(self, input_data):
        """
        进行预测
        """
        # 设置输入
        self.interpreter.set_tensor(self.input_details[0]['index'], input_data)
        
        # 运行推理
        self.interpreter.invoke()
        
        # 获取输出
        output_data = self.interpreter.get_tensor(self.output_details[0]['index'])
        
        return output_data
    
    def benchmark(self, input_data, num_runs=1000):
        """
        性能基准测试
        """
        print(f"开始TensorFlow Lite性能测试,运行 {num_runs} 次...")
        
        # 预热
        for _ in range(10):
            self.predict(input_data)
        
        # 测试
        start_time = time.time()
        for _ in range(num_runs):
            self.predict(input_data)
        end_time = time.time()
        
        total_time = end_time - start_time
        avg_time = total_time / num_runs
        
        print(f"总耗时: {total_time:.4f} 秒")
        print(f"平均推理时间: {avg_time*1000:.4f} ms")
        print(f"QPS: {num_runs/total_time:.2f}")

# Use the TensorFlow Lite predictor
tflite_predictor = TFLitePredictor('model.tflite')

# Create test data
test_input = np.random.random((1, 784)).astype(np.float32)

# Run a prediction
prediction = tflite_predictor.predict(test_input)
print(f"Prediction: {prediction}")

# Benchmark
tflite_predictor.benchmark(test_input, num_runs=1000)
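
After conversion it is good practice to confirm that the TensorFlow Lite model agrees with the original Keras model on the same inputs. A minimal sketch using the predictor defined above (the tolerance of 1e-3 is an arbitrary choice):

python
def compare_tflite_with_keras(keras_model, tflite_predictor, num_samples=10, atol=1e-3):
    """Compare Keras and TFLite outputs on random inputs."""
    inputs = np.random.random((num_samples, 784)).astype(np.float32)
    keras_outputs = keras_model.predict(inputs, verbose=0)

    max_diff = 0.0
    for i in range(num_samples):
        tflite_output = tflite_predictor.predict(inputs[i:i + 1])
        max_diff = max(max_diff, float(np.max(np.abs(tflite_output - keras_outputs[i]))))

    print(f"Maximum absolute difference: {max_diff:.6f}")
    print("Consistent" if max_diff < atol else "Outputs differ more than the tolerance")

# compare_tflite_with_keras(sample_model, tflite_predictor)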

Docker Containerized Deployment

Creating the Dockerfile

python
def create_dockerfile():
    """
    创建Dockerfile
    """
    dockerfile_content = """
# 使用官方Python运行时作为基础镜像
FROM python:3.9-slim

# 设置工作目录
WORKDIR /app

# 复制requirements文件
COPY requirements.txt .

# 安装Python依赖
RUN pip install --no-cache-dir -r requirements.txt

# 复制应用代码
COPY . .

# 暴露端口
EXPOSE 5000

# 设置环境变量
ENV FLASK_APP=app.py
ENV FLASK_ENV=production

# 运行应用
CMD ["python", "app.py"]
"""
    
    with open('Dockerfile', 'w') as f:
        f.write(dockerfile_content)
    
    print("Dockerfile已创建")

def create_requirements_txt():
    """
    创建requirements.txt文件
    """
    requirements = """
tensorflow==2.13.0
flask==2.3.2
numpy==1.24.3
requests==2.31.0
gunicorn==21.2.0
"""
    
    with open('requirements.txt', 'w') as f:
        f.write(requirements.strip())
    
    print("requirements.txt已创建")

def create_docker_compose():
    """
    创建Docker Compose文件
    """
    docker_compose_content = """
version: '3.8'

services:
  ml-api:
    build: .
    ports:
      - "5000:5000"
    environment:
      - FLASK_ENV=production
    volumes:
      - ./models:/app/models
    restart: unless-stopped
    
  nginx:
    image: nginx:alpine
    ports:
      - "80:80"
    volumes:
      - ./nginx.conf:/etc/nginx/nginx.conf
    depends_on:
      - ml-api
    restart: unless-stopped
"""
    
    with open('docker-compose.yml', 'w') as f:
        f.write(docker_compose_content)
    
    print("docker-compose.yml已创建")

def create_nginx_config():
    """
    创建Nginx配置文件
    """
    nginx_config = """
events {
    worker_connections 1024;
}

http {
    upstream ml_api {
        server ml-api:5000;
    }
    
    server {
        listen 80;
        
        location / {
            proxy_pass http://ml_api;
            proxy_set_header Host $host;
            proxy_set_header X-Real-IP $remote_addr;
            proxy_set_header X-Forwarded-For $proxy_add_x_forwarded_for;
            proxy_set_header X-Forwarded-Proto $scheme;
            
            # Timeout settings
            proxy_connect_timeout 60s;
            proxy_send_timeout 60s;
            proxy_read_timeout 60s;
        }
        
        # Health check
        location /health {
            proxy_pass http://ml_api/health;
        }
    }
}
"""
    
    with open('nginx.conf', 'w') as f:
        f.write(nginx_config)
    
    print("nginx.conf created")

# Create the Docker-related files
create_dockerfile()
create_requirements_txt()
create_docker_compose()
create_nginx_config()
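
A .dockerignore file keeps the build context small and prevents local artifacts from leaking into the image. Following the same file-generating style, a minimal sketch (the ignore patterns are illustrative):

python
def create_dockerignore():
    """Create a .dockerignore file to keep the Docker build context small."""
    dockerignore_content = """
__pycache__/
*.pyc
.git
.venv
docker-compose*.yml
"""
    with open('.dockerignore', 'w') as f:
        f.write(dockerignore_content.strip() + "\n")
    print(".dockerignore created")

# create_dockerignore()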

A Production-Grade Flask Application

python
def create_production_app():
    """
    创建生产级Flask应用
    """
    app_content = """
import os
import logging
from flask import Flask, request, jsonify
import tensorflow as tf
import numpy as np
from werkzeug.middleware.proxy_fix import ProxyFix
import time
from functools import wraps

# Configure logging
logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(name)s - %(levelname)s - %(message)s'
)
logger = logging.getLogger(__name__)

app = Flask(__name__)
app.wsgi_app = ProxyFix(app.wsgi_app, x_for=1, x_proto=1, x_host=1, x_prefix=1)

# Global state
model = None
model_version = "1.0"

def load_model():
    \"\"\"Load the model\"\"\"
    global model
    try:
        model_path = os.environ.get('MODEL_PATH', './models/mnist_classifier/1')
        model = tf.keras.models.load_model(model_path)
        logger.info(f"模型加载成功: {model_path}")
        return True
    except Exception as e:
        logger.error(f"模型加载失败: {e}")
        return False

def require_model(f):
    \"\"\"装饰器:确保模型已加载\"\"\"
    @wraps(f)
    def decorated_function(*args, **kwargs):
        if model is None:
            return jsonify({'error': '模型未加载'}), 503
        return f(*args, **kwargs)
    return decorated_function

# Load the model at import time (before_first_request was removed in Flask 2.3)
load_model()

@app.route('/health', methods=['GET'])
def health():
    \"\"\"Health check\"\"\"
    return jsonify({
        'status': 'healthy',
        'model_loaded': model is not None,
        'model_version': model_version,
        'timestamp': time.time()
    })

@app.route('/predict', methods=['POST'])
@require_model
def predict():
    \"\"\"预测接口\"\"\"
    try:
        start_time = time.time()
        
        data = request.get_json()
        if 'instances' not in data:
            return jsonify({'error': '缺少instances字段'}), 400
        
        instances = np.array(data['instances'])
        predictions = model.predict(instances)
        
        processing_time = time.time() - start_time
        
        logger.info(f"预测完成,耗时: {processing_time:.4f}s,样本数: {len(instances)}")
        
        return jsonify({
            'predictions': predictions.tolist(),
            'model_version': model_version,
            'processing_time': processing_time
        })
        
    except Exception as e:
        logger.error(f"预测失败: {e}")
        return jsonify({'error': str(e)}), 500

@app.errorhandler(404)
def not_found(error):
    return jsonify({'error': '接口不存在'}), 404

@app.errorhandler(500)
def internal_error(error):
    return jsonify({'error': '内部服务器错误'}), 500

if __name__ == '__main__':
    port = int(os.environ.get('PORT', 5000))
    debug = os.environ.get('FLASK_ENV') == 'development'
    
    app.run(host='0.0.0.0', port=port, debug=debug)
"""
    
    with open('app.py', 'w') as f:
        f.write(app_content)
    
    print("生产级Flask应用已创建: app.py")

create_production_app()
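
requirements.txt already includes gunicorn, and in production the app is usually served by gunicorn rather than Flask's built-in development server. A minimal sketch in the same file-generating style; the worker and thread counts are illustrative and should be tuned to the hardware:

python
def create_gunicorn_config():
    """Write a simple gunicorn configuration for serving app:app."""
    gunicorn_conf = """
bind = "0.0.0.0:5000"
workers = 2        # number of worker processes
threads = 4        # threads per worker
timeout = 60       # seconds before a hung worker is restarted
"""
    with open('gunicorn.conf.py', 'w') as f:
        f.write(gunicorn_conf.strip() + "\n")
    print("gunicorn.conf.py created")
    print("Start with: gunicorn -c gunicorn.conf.py app:app")

# create_gunicorn_config()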

Kubernetes Deployment

Kubernetes Configuration Files

python
def create_k8s_deployment():
    """
    创建Kubernetes部署配置
    """
    deployment_yaml = """
apiVersion: apps/v1
kind: Deployment
metadata:
  name: ml-model-deployment
  labels:
    app: ml-model
spec:
  replicas: 3
  selector:
    matchLabels:
      app: ml-model
  template:
    metadata:
      labels:
        app: ml-model
    spec:
      containers:
      - name: ml-model
        image: ml-model:latest
        ports:
        - containerPort: 5000
        env:
        - name: MODEL_PATH
          value: "/app/models/mnist_classifier/1"
        resources:
          requests:
            memory: "512Mi"
            cpu: "250m"
          limits:
            memory: "1Gi"
            cpu: "500m"
        livenessProbe:
          httpGet:
            path: /health
            port: 5000
          initialDelaySeconds: 30
          periodSeconds: 10
        readinessProbe:
          httpGet:
            path: /health
            port: 5000
          initialDelaySeconds: 5
          periodSeconds: 5
---
apiVersion: v1
kind: Service
metadata:
  name: ml-model-service
spec:
  selector:
    app: ml-model
  ports:
    - protocol: TCP
      port: 80
      targetPort: 5000
  type: LoadBalancer
---
apiVersion: autoscaling/v2
kind: HorizontalPodAutoscaler
metadata:
  name: ml-model-hpa
spec:
  scaleTargetRef:
    apiVersion: apps/v1
    kind: Deployment
    name: ml-model-deployment
  minReplicas: 2
  maxReplicas: 10
  metrics:
  - type: Resource
    resource:
      name: cpu
      target:
        type: Utilization
        averageUtilization: 70
  - type: Resource
    resource:
      name: memory
      target:
        type: Utilization
        averageUtilization: 80
"""
    
    with open('k8s-deployment.yaml', 'w') as f:
        f.write(deployment_yaml)
    
    print("Kubernetes部署配置已创建: k8s-deployment.yaml")

def create_k8s_configmap():
    """
    创建Kubernetes ConfigMap
    """
    configmap_yaml = """
apiVersion: v1
kind: ConfigMap
metadata:
  name: ml-model-config
data:
  MODEL_VERSION: "1.0"
  LOG_LEVEL: "INFO"
  MAX_BATCH_SIZE: "32"
  TIMEOUT_SECONDS: "30"
"""
    
    with open('k8s-configmap.yaml', 'w') as f:
        f.write(configmap_yaml)
    
    print("Kubernetes ConfigMap已创建: k8s-configmap.yaml")

def create_k8s_ingress():
    """
    创建Kubernetes Ingress
    """
    ingress_yaml = """
apiVersion: networking.k8s.io/v1
kind: Ingress
metadata:
  name: ml-model-ingress
  annotations:
    nginx.ingress.kubernetes.io/rewrite-target: /
    nginx.ingress.kubernetes.io/ssl-redirect: "false"
    nginx.ingress.kubernetes.io/proxy-body-size: "10m"
spec:
  rules:
  - host: ml-api.example.com
    http:
      paths:
      - path: /
        pathType: Prefix
        backend:
          service:
            name: ml-model-service
            port:
              number: 80
"""
    
    with open('k8s-ingress.yaml', 'w') as f:
        f.write(ingress_yaml)
    
    print("Kubernetes Ingress已创建: k8s-ingress.yaml")

# 创建Kubernetes配置文件
create_k8s_deployment()
create_k8s_configmap()
create_k8s_ingress()
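
As written, the ConfigMap above is created but never referenced by the Deployment. One common way to wire it in is an envFrom entry on the container, which turns every key in the ConfigMap into an environment variable. The helper below only prints the YAML fragment that would be added under the container in k8s-deployment.yaml (a sketch, not generated by the functions above):

python
def print_configmap_env_snippet():
    """Print a YAML fragment that injects ml-model-config into the container environment."""
    snippet = """
        envFrom:
        - configMapRef:
            name: ml-model-config
"""
    print(snippet.strip("\n"))

# print_configmap_env_snippet()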

Monitoring and Logging

Application Monitoring

python
def create_monitoring_app():
    """
    创建带监控的应用
    """
    monitoring_code = """
import time
import psutil
import threading
from collections import defaultdict, deque
from flask import Flask, jsonify, request
import json

class ModelMonitor:
    def __init__(self):
        self.metrics = {
            'request_count': defaultdict(int),
            'response_times': deque(maxlen=1000),
            'error_count': 0,
            'model_predictions': 0,
            'cpu_usage': deque(maxlen=100),
            'memory_usage': deque(maxlen=100),
            'start_time': time.time()
        }
        
        # Start the system-monitoring background thread
        self.monitoring_thread = threading.Thread(target=self._monitor_system)
        self.monitoring_thread.daemon = True
        self.monitoring_thread.start()
    
    def _monitor_system(self):
        while True:
            self.metrics['cpu_usage'].append(psutil.cpu_percent())
            self.metrics['memory_usage'].append(psutil.virtual_memory().percent)
            time.sleep(5)
    
    def record_request(self, endpoint, response_time, success=True):
        self.metrics['request_count'][endpoint] += 1
        self.metrics['response_times'].append(response_time)
        
        if not success:
            self.metrics['error_count'] += 1
    
    def record_prediction(self):
        self.metrics['model_predictions'] += 1
    
    def get_metrics(self):
        uptime = time.time() - self.metrics['start_time']
        
        return {
            'uptime_seconds': uptime,
            'total_requests': sum(self.metrics['request_count'].values()),
            'request_count_by_endpoint': dict(self.metrics['request_count']),
            'error_count': self.metrics['error_count'],
            'model_predictions': self.metrics['model_predictions'],
            'average_response_time': sum(self.metrics['response_times']) / len(self.metrics['response_times']) if self.metrics['response_times'] else 0,
            'current_cpu_usage': self.metrics['cpu_usage'][-1] if self.metrics['cpu_usage'] else 0,
            'current_memory_usage': self.metrics['memory_usage'][-1] if self.metrics['memory_usage'] else 0,
            'average_cpu_usage': sum(self.metrics['cpu_usage']) / len(self.metrics['cpu_usage']) if self.metrics['cpu_usage'] else 0,
            'average_memory_usage': sum(self.metrics['memory_usage']) / len(self.metrics['memory_usage']) if self.metrics['memory_usage'] else 0
        }

# Global monitor instance
monitor = ModelMonitor()

def create_monitored_app():
    app = Flask(__name__)
    
    @app.before_request
    def before_request():
        request.start_time = time.time()
    
    @app.after_request
    def after_request(response):
        if hasattr(request, 'start_time'):
            response_time = time.time() - request.start_time
            success = response.status_code < 400
            monitor.record_request(request.endpoint, response_time, success)
        return response
    
    @app.route('/metrics')
    def get_metrics():
        return jsonify(monitor.get_metrics())
    
    @app.route('/health')
    def health():
        metrics = monitor.get_metrics()
        
        # Health-check logic
        is_healthy = (
            metrics['current_cpu_usage'] < 90 and
            metrics['current_memory_usage'] < 90 and
            metrics['average_response_time'] < 5.0
        )
        
        return jsonify({
            'status': 'healthy' if is_healthy else 'unhealthy',
            'checks': {
                'cpu_ok': metrics['current_cpu_usage'] < 90,
                'memory_ok': metrics['current_memory_usage'] < 90,
                'response_time_ok': metrics['average_response_time'] < 5.0
            }
        })
    
    return app
"""
    
    with open('monitoring.py', 'w') as f:
        f.write(monitoring_code)
    
    print("监控模块已创建: monitoring.py")

create_monitoring_app()

Prometheus Integration

python
def create_prometheus_config():
    """
    创建Prometheus配置
    """
    prometheus_yml = """
global:
  scrape_interval: 15s

scrape_configs:
  - job_name: 'ml-model'
    static_configs:
      - targets: ['ml-model-service:80']
    metrics_path: '/metrics'
    scrape_interval: 10s
    
  - job_name: 'kubernetes-pods'
    kubernetes_sd_configs:
      - role: pod
    relabel_configs:
      - source_labels: [__meta_kubernetes_pod_annotation_prometheus_io_scrape]
        action: keep
        regex: true
"""
    
    with open('prometheus.yml', 'w') as f:
        f.write(prometheus_yml)
    
    print("Prometheus配置已创建: prometheus.yml")

def create_grafana_dashboard():
    """
    创建Grafana仪表板配置
    """
    dashboard_json = {
        "dashboard": {
            "title": "ML Model Monitoring",
            "panels": [
                {
                    "title": "Request Rate",
                    "type": "graph",
                    "targets": [
                        {
                            "expr": "rate(ml_requests_total[5m])",
                            "legendFormat": "Requests/sec"
                        }
                    ]
                },
                {
                    "title": "Response Time",
                    "type": "graph",
                    "targets": [
                        {
                            "expr": "ml_response_time_seconds",
                            "legendFormat": "Response Time"
                        }
                    ]
                },
                {
                    "title": "Error Rate",
                    "type": "graph",
                    "targets": [
                        {
                            "expr": "rate(ml_errors_total[5m])",
                            "legendFormat": "Errors/sec"
                        }
                    ]
                }
            ]
        }
    }
    
    with open('grafana-dashboard.json', 'w') as f:
        json.dump(dashboard_json, f, indent=2)
    
    print("Grafana仪表板配置已创建: grafana-dashboard.json")

create_prometheus_config()
create_grafana_dashboard()
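
Prometheus scrapes metrics in its own text exposition format, so the JSON /metrics endpoint from the monitoring module above would not be parsed by the scrape job defined here. A common approach is the prometheus_client library; the sketch below assumes it is installed and uses metric names matching the Grafana dashboard above:

python
def create_prometheus_metrics_app():
    """Flask app exposing Prometheus-format metrics via prometheus_client."""
    from flask import Flask
    from prometheus_client import Counter, Histogram, generate_latest, CONTENT_TYPE_LATEST

    app = Flask(__name__)
    request_counter = Counter('ml_requests_total', 'Total prediction requests')
    latency_histogram = Histogram('ml_response_time_seconds', 'Prediction latency in seconds')

    @app.route('/predict_dummy', methods=['POST'])
    def predict_dummy():
        # Record one request and its latency (the model call is omitted in this sketch)
        request_counter.inc()
        with latency_histogram.time():
            time.sleep(0.001)
        return {'status': 'ok'}

    @app.route('/metrics')
    def metrics():
        return generate_latest(), 200, {'Content-Type': CONTENT_TYPE_LATEST}

    return app

# prometheus_app = create_prometheus_metrics_app()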

Performance Optimization

Model Optimization

python
def optimize_model_for_inference(model):
    """
    优化模型用于推理
    """
    # 1. 转换为推理模式
    @tf.function
    def inference_func(x):
        return model(x, training=False)
    
    # 2. 创建具体函数
    concrete_func = inference_func.get_concrete_function(
        tf.TensorSpec(shape=[None, 784], dtype=tf.float32)
    )
    
    # 3. 冻结图
    frozen_func = convert_variables_to_constants_v2(concrete_func)
    
    return frozen_func

def batch_prediction_optimization():
    """
    批量预测优化
    """
    class BatchPredictor:
        def __init__(self, model, max_batch_size=32, max_wait_time=0.1):
            self.model = model
            self.max_batch_size = max_batch_size
            self.max_wait_time = max_wait_time
            self.batch_queue = []
            self.results_queue = {}
            self.batch_id = 0
            
        def predict(self, input_data):
            # Add the request to the batch queue
            batch_id = self.batch_id
            self.batch_id += 1
            
            self.batch_queue.append((batch_id, input_data))
            
            # Process the batch if it has reached the maximum size
            if len(self.batch_queue) >= self.max_batch_size:
                self._process_batch()
            
            # Wait for the result
            start_time = time.time()
            while batch_id not in self.results_queue:
                if time.time() - start_time > self.max_wait_time:
                    self._process_batch()
                time.sleep(0.001)
            
            result = self.results_queue.pop(batch_id)
            return result
        
        def _process_batch(self):
            if not self.batch_queue:
                return
            
            # Assemble the batch
            batch_ids = []
            batch_data = []
            
            for batch_id, data in self.batch_queue:
                batch_ids.append(batch_id)
                batch_data.append(data)
            
            # Clear the queue
            self.batch_queue = []
            
            # Run one batched prediction
            batch_input = np.array(batch_data)
            batch_predictions = self.model.predict(batch_input)
            
            # Store the individual results
            for i, batch_id in enumerate(batch_ids):
                self.results_queue[batch_id] = batch_predictions[i]
    
    return BatchPredictor

def create_model_cache():
    """
    创建模型缓存
    """
    import hashlib
    from functools import lru_cache
    
    class ModelCache:
        def __init__(self, max_size=1000):
            self.cache = {}
            self.max_size = max_size
            self.access_count = {}
        
        def _hash_input(self, input_data):
            return hashlib.md5(input_data.tobytes()).hexdigest()
        
        def get_prediction(self, model, input_data):
            input_hash = self._hash_input(input_data)
            
            if input_hash in self.cache:
                self.access_count[input_hash] += 1
                return self.cache[input_hash]
            
            # Compute the prediction
            prediction = model.predict(input_data)
            
            # Cache management
            if len(self.cache) >= self.max_size:
                # Evict the least-used entry
                least_used = min(self.access_count.items(), key=lambda x: x[1])
                del self.cache[least_used[0]]
                del self.access_count[least_used[0]]
            
            # Add to the cache
            self.cache[input_hash] = prediction
            self.access_count[input_hash] = 1
            
            return prediction
    
    return ModelCache()
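
A brief usage sketch of the caching helper above; the cache only pays off when identical inputs recur, for example repeated requests for the same item:

python
# Hypothetical usage of the prediction cache with the sample model
cache = create_model_cache()

repeated_input = np.random.random((1, 784)).astype(np.float32)

first = cache.get_prediction(sample_model, repeated_input)   # computed and cached
second = cache.get_prediction(sample_model, repeated_input)  # served from the cache

print(np.allclose(first, second))  # True: the second call returns the cached prediction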

Security Considerations

API Security

python
def create_secure_api():
    """
    创建安全的API
    """
    secure_api_code = """
import time
from functools import wraps
from flask import Flask, request, jsonify
import numpy as np
from werkzeug.security import check_password_hash

app = Flask(__name__)
app.config['SECRET_KEY'] = 'your-secret-key-here'

# The model is assumed to be loaded elsewhere, e.g. at startup
model = None  # replace with a loaded model, e.g. tf.keras.models.load_model(...)

# API key store (use a database in a real application)
API_KEYS = {
    'client1': 'hashed_api_key_1',
    'client2': 'hashed_api_key_2'
}

# Rate limiting
from flask_limiter import Limiter
from flask_limiter.util import get_remote_address

limiter = Limiter(
    app,
    key_func=get_remote_address,
    default_limits=["100 per hour"]
)

def require_api_key(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        api_key = request.headers.get('X-API-Key')
        
        if not api_key:
            return jsonify({'error': 'Missing API key'}), 401
        
        # Validate the API key
        client_id = request.headers.get('X-Client-ID')
        if not client_id or client_id not in API_KEYS:
            return jsonify({'error': 'Invalid client ID'}), 401
        
        expected_key = API_KEYS[client_id]
        if not check_password_hash(expected_key, api_key):
            return jsonify({'error': 'Invalid API key'}), 401
        
        return f(*args, **kwargs)
    return decorated_function

def validate_input(f):
    @wraps(f)
    def decorated_function(*args, **kwargs):
        data = request.get_json()
        
        # Input validation
        if not data or 'instances' not in data:
            return jsonify({'error': 'Invalid input data'}), 400
        
        instances = data['instances']
        
        # Check the data type and size
        if not isinstance(instances, list):
            return jsonify({'error': 'instances must be a list'}), 400
        
        if len(instances) > 100:  # limit the batch size
            return jsonify({'error': 'Batch size must not exceed 100'}), 400
        
        # Check the shape of each instance
        for instance in instances:
            if not isinstance(instance, list) or len(instance) != 784:
                return jsonify({'error': 'Each instance must be a list of length 784'}), 400
        
        return f(*args, **kwargs)
    return decorated_function

@app.route('/predict', methods=['POST'])
@limiter.limit("10 per minute")
@require_api_key
@validate_input
def secure_predict():
    try:
        data = request.get_json()
        instances = np.array(data['instances'])
        
        # Range-check the inputs
        if np.any(instances < 0) or np.any(instances > 1):
            return jsonify({'error': 'Input values must be in the range 0-1'}), 400
        
        # Run the prediction
        predictions = model.predict(instances)
        
        return jsonify({
            'predictions': predictions.tolist(),
            'timestamp': time.time()
        })
        
    except Exception as e:
        # Do not expose internal error details
        return jsonify({'error': 'Prediction failed'}), 500

@app.errorhandler(429)
def ratelimit_handler(e):
    return jsonify({'error': 'Too many requests, please try again later'}), 429
"""
    
    with open('secure_api.py', 'w') as f:
        f.write(secure_api_code)
    
    print("安全API已创建: secure_api.py")

create_secure_api()

Summary

This chapter covered the main ways to deploy TensorFlow models.

Key points:

  1. Deployment choice: pick the deployment option that matches your requirements
  2. TensorFlow Serving: high-performance production serving
  3. Containerization: use Docker to guarantee environment consistency
  4. Kubernetes: large-scale automated deployment and management
  5. Monitoring and logging: a complete monitoring and logging setup
  6. Performance optimization: model optimization and faster inference
  7. Security: API security and access control

Best practices:

  • Choose a deployment architecture that fits your workload
  • Put thorough monitoring and logging in place
  • Account for security and access control
  • Run performance tests and optimize accordingly
  • Set up a CI/CD pipeline
  • Prepare a disaster recovery plan
  • Update and maintain the deployment regularly

Model deployment is a complex engineering problem that requires balancing performance, reliability, security, and maintainability.
