装饰器是Python最强大的特性之一,用好装饰器能让你的代码更简洁、更易复用。很多框架(Flask、Django等)大量使用装饰器。学会写装饰器,你也能写出专业级别的代码。

一、装饰器基础

函数装饰器

# 装饰器本质:高阶函数,接收函数作为参数,返回新函数

def my_decorator(func):
    """简单装饰器"""
    def wrapper(*args, **kwargs):
        print("调用函数前")
        result = func(*args, **kwargs)
        print("调用函数后")
        return result
    return wrapper

@my_decorator
def say_hello():
    print("Hello!")

# 等价于
say_hello = my_decorator(say_hello)

# 调用
say_hello()
# 输出:
# 调用函数前
# Hello!
# 调用函数后

带参数的装饰器

def repeat(times):
    """带参数的装饰器工厂"""
    def decorator(func):
        def wrapper(*args, **kwargs):
            for _ in range(times):
                result = func(*args, **kwargs)
            return result
        return wrapper
    return decorator

@repeat(times=3)
def greet(name):
    print(f"你好, {name}!")

greet("张三")
# 打印3次 "你好, 张三!"

保存原函数信息

from functools import wraps

def my_decorator(func):
    @wraps(func)  # 保留原函数信息
    def wrapper(*args, **kwargs):
        return func(*args, **kwargs)
    return wrapper

@my_decorator
def original():
    """原始函数的文档"""
    pass

print(original.__name__)  # original(而不是wrapper)
print(original.__doc__)    # 原始函数的文档

二、实战:计时装饰器

import time
from functools import wraps

def timer(func):
    """计时装饰器"""
    @wraps(func)
    def wrapper(*args, **kwargs):
        start = time.time()
        result = func(*args, **kwargs)
        end = time.time()
        print(f"{func.__name__} 执行耗时: {end - start:.2f}秒")
        return result
    return wrapper

def async_timer(func):
    """异步函数计时"""
    @wraps(func)
    async def wrapper(*args, **kwargs):
        start = time.time()
        result = await func(*args, **kwargs)
        end = time.time()
        print(f"{func.__name__} 执行耗时: {end - start:.2f}秒")
        return result
    return wrapper

@timer
def slow_function():
    time.sleep(1)
    return "完成"

# 使用
result = slow_function()
# slow_function 执行耗时: 1.00秒

三、实战:重试装饰器

import time
import functools
from typing import Type, Tuple

def retry(
    max_attempts=3, 
    delay=1, 
    backoff=2,
    exceptions=(Exception,)
):
    """重试装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            attempts = 0
            current_delay = delay
            
            while attempts < max_attempts:
                try:
                    return func(*args, **kwargs)
                except exceptions as e:
                    attempts += 1
                    if attempts >= max_attempts:
                        raise
                    
                    print(f"尝试 {attempts} 失败: {e}, {current_delay}秒后重试...")
                    time.sleep(current_delay)
                    current_delay *= backoff
            
        return wrapper
    return decorator

@retry(max_attempts=3, delay=1, exceptions=(ConnectionError, TimeoutError))
def fetch_data(url):
    """模拟可能失败的网络请求"""
    import random
    if random.random() < 0.7:  # 70%概率失败
        raise ConnectionError("网络连接失败")
    return {"data": "success"}

# 使用
try:
    data = fetch_data("https://api.example.com")
except ConnectionError as e:
    print(f"最终失败: {e}")

四、实战:缓存装饰器

import functools
import time
import hashlib
import json

def cache(ttl=300):
    """缓存装饰器(基于函数参数)"""
    def decorator(func):
        cache_data = {}
        
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 生成缓存键
            key = (args, tuple(sorted(kwargs.items())))
            key_str = json.dumps(key, sort_keys=True, default=str)
            key_hash = hashlib.md5(key_str.encode()).hexdigest()
            
            # 检查缓存
            if key_hash in cache_data:
                cached_time, cached_result = cache_data[key_hash]
                if time.time() - cached_time < ttl:
                    print(f"[缓存命中] {func.__name__}")
                    return cached_result
            
            # 执行函数
            result = func(*args, **kwargs)
            
            # 保存缓存
            cache_data[key_hash] = (time.time(), result)
            return result
        
        wrapper.cache_clear = lambda: cache_data.clear()
        return wrapper
    return decorator

@cache(ttl=60)
def expensive_computation(n):
    """耗时计算"""
    time.sleep(2)  # 模拟耗时操作
    return n ** 2

# 使用
print(expensive_computation(100))  # 第一次,计算
print(expensive_computation(100))  # 第二次,缓存命中
expensive_computation.cache_clear()  # 清除缓存

五、实战:日志装饰器

import logging
import functools
import json
from datetime import datetime

logger = logging.getLogger(__name__)

def log_calls(level=logging.INFO):
    """记录函数调用的装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 记录调用
            logger.log(level, f"[调用] {func.__name__}(args={args}, kwargs={kwargs})")
            
            start_time = datetime.now()
            try:
                result = func(*args, **kwargs)
                duration = (datetime.now() - start_time).total_seconds()
                logger.log(level, f"[返回] {func.__name__} -> {result!r} (耗时: {duration:.3f}s)")
                return result
            except Exception as e:
                logger.error(f"[异常] {func.__name__}: {e}")
                raise
        
        @wraps(func)
        def async_wrapper(*args, **kwargs):
            # 异步版本
            import asyncio
            logger.log(level, f"[异步调用] {func.__name__}")
            
            async def run():
                try:
                    result = await func(*args, **kwargs)
                    logger.log(level, f"[异步返回] {func.__name__} -> 完成")
                    return result
                except Exception as e:
                    logger.error(f"[异步异常] {func.__name__}: {e}")
                    raise
            
            return run()
        
        if functools.iscoroutinefunction(func):
            return async_wrapper
        return wrapper
    return decorator

@log_calls()
def process_data(data):
    return [x * 2 for x in data]

@log_calls()
async def fetch_async(url):
    return await asyncio.get_event_loop().run_in_executor(None, lambda: f"响应: {url}")

# 使用
result = process_data([1, 2, 3])
# [调用] process_data(args=([1, 2, 3],), kwargs={})
# [返回] process_data -> [2, 4, 6] (耗时: 0.000s)

六、实战:权限验证装饰器

from functools import wraps
from typing import List, Callable

# 模拟用户和权限
current_user = {'id': 1, 'role': 'admin', 'permissions': ['read', 'write', 'delete']}

def require_permission(*permissions):
    """权限验证装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            user = kwargs.get('user') or current_user
            
            for permission in permissions:
                if permission not in user.get('permissions', []):
                    raise PermissionError(f"缺少权限: {permission}")
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

def require_role(*roles):
    """角色验证装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            user = kwargs.get('user') or current_user
            
            if user.get('role') not in roles:
                raise PermissionError(f"角色验证失败,需要: {roles}")
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

class RateLimit:
    """限流器"""
    def __init__(self, max_calls: int, period: float):
        self.max_calls = max_calls
        self.period = period
        self.calls = []
    
    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            now = time.time()
            self.calls = [t for t in self.calls if now - t < self.period]
            
            if len(self.calls) >= self.max_calls:
                raise Exception(f"请求过于频繁,请{self.period}秒后重试")
            
            self.calls.append(now)
            return func(*args, **kwargs)
        return wrapper

# 使用
@require_permission('read', 'write')
def update_user(user_id, data, user=None):
    return {"success": True}

@require_role('admin', 'superadmin')
def delete_user(user_id, user=None):
    return {"success": True}

@RateLimit(max_calls=10, period=60)
def api_endpoint():
    return {"message": "请求成功"}

# 测试
try:
    update_user(1, {"name": "张三"})
except PermissionError as e:
    print(f"权限错误: {e}")

七、实战:类装饰器

from functools import wraps

class Singleton:
    """单例装饰器"""
    _instances = {}
    
    def __init__(self, cls):
        self._cls = cls
        self._instances[cls] = None
    
    def __call__(self, *args, **kwargs):
        if self._instances.get(self._cls) is None:
            self._instances[self._cls] = self._cls(*args, **kwargs)
        return self._instances[self._cls]

@Singleton
class Database:
    def __init__(self):
        print("数据库连接已建立")
    
    def query(self, sql):
        return f"执行: {sql}"

# 使用 - 多次实例化返回同一对象
db1 = Database()
db2 = Database()
print(db1 is db2)  # True

class Typed:
    """类型检查装饰器"""
    def __init__(self, expected_type):
        self.expected = expected_type
    
    def __call__(self, func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 检查位置参数
            for i, arg in enumerate(args):
                if not isinstance(arg, self.expected):
                    raise TypeError(f"参数{i}类型错误: 期望{self.expected}, 得到{type(arg)}")
            
            # 检查关键字参数
            for name, value in kwargs.items():
                if not isinstance(value, self.expected):
                    raise TypeError(f"参数'{name}'类型错误")
            
            return func(*args, **kwargs)
        return wrapper

@Typed(int)
def add(a, b):
    return a + b

add(1, 2)  # 正常
# add("1", "2")  # TypeError

八、实战:表单验证装饰器

from functools import wraps
from typing import Any, Callable

class ValidationError(Exception):
    pass

def validate(**validators):
    """参数验证装饰器"""
    def decorator(func):
        @wraps(func)
        def wrapper(*args, **kwargs):
            # 获取函数参数
            import inspect
            sig = inspect.signature(func)
            bound = sig.bind(*args, **kwargs)
            bound.apply_defaults()
            
            # 验证
            for name, value in bound.arguments.items():
                if name in validators:
                    validator = validators[name]
                    
                    if isinstance(validator, tuple):
                        # (类型, 其他检查)
                        expected_type, *checks = validator
                        if not isinstance(value, expected_type):
                            raise ValidationError(f"{name}必须是{expected_type.__name__}类型")
                        
                        for check in checks:
                            if not check(value):
                                raise ValidationError(f"{name}验证失败")
                    else:
                        # 直接是类型
                        if not isinstance(value, validator):
                            raise ValidationError(f"{name}必须是{validator.__name__}类型")
            
            return func(*args, **kwargs)
        return wrapper
    return decorator

def validate_email(email):
    """验证邮箱"""
    import re
    return bool(re.match(r'^[\w\.-]+@[\w\.-]+\.\w+$', email))

def validate_positive(num):
    """验证正数"""
    return num > 0

@validate(
    name=str,
    age=(int, lambda x: 0 <= x <= 150),
    email=(str, validate_email),
    salary=(float, validate_positive)
)
def create_employee(name, age, email, salary):
    return {"name": name, "age": age, "email": email, "salary": salary}

# 使用
try:
    emp = create_employee("张三", 25, "zhangsan@example.com", 10000)
    print(emp)
except ValidationError as e:
    print(f"验证失败: {e}")

九、装饰器链

# 多个装饰器叠加使用
@timer
@retry(max_attempts=3)
@cache(ttl=60)
def complex_operation(data):
    """复杂操作"""
    return data

# 等价于
complex_operation = timer(retry(max_attempts=3)(cache(ttl=60)(complex_operation)))

# 注意:从下往上执行
# 1. cache(ttl=60)(complex_operation) 先执行
# 2. retry 装饰 cache 的结果
# 3. timer 装饰 retry 的结果

# 装饰顺序影响:
# @a
# @b
# def f(): pass
# 等价于: f = a(b(f))

十、实战:Flask风格路由装饰器

from functools import wraps
from collections import defaultdict

class Router:
    """简易路由装饰器"""
    
    def __init__(self):
        self.routes = defaultdict(list)
    
    def route(self, path, methods=None):
        """路由装饰器"""
        methods = methods or ['GET']
        
        def decorator(func):
            @wraps(func)
            def wrapper(*args, **kwargs):
                return func(*args, **kwargs)
            
            wrapper._route_path = path
            wrapper._route_methods = methods
            self.routes[path].append(wrapper)
            return wrapper
        return decorator
    
    def match(self, path, method='GET'):
        """匹配路由"""
        if path in self.routes:
            for func in self.routes[path]:
                if method in func._route_methods:
                    return func
        return None

# 创建路由器
router = Router()

@router.route('/users', methods=['GET'])
def get_users():
    return {"users": []}

@router.route('/users', methods=['POST'])
def create_user():
    return {"created": True}

@router.route('/users/<int:user_id>')
def get_user(user_id):
    return {"user_id": user_id}

# 使用
def dispatch(path, method='GET'):
    handler = router.match(path, method)
    if handler:
        # 解析路径参数(简化版)
        return handler(user_id=1)
    return {"error": "Not Found"}

print(dispatch('/users', 'GET'))
print(dispatch('/users', 'POST'))

总结

装饰器使用要点:

  1. 基础语法 - @decorator 本质是 func = decorator(func)
  2. @wraps - 保留原函数信息
  3. 带参数 - 装饰器工厂返回装饰器
  4. 通用签名 - *args, **kwargs 处理任意参数
  5. 组合使用 - 多个装饰器按从下到上顺序执行
  6. 常见场景 - 计时、重试、缓存、日志、权限、验证

装饰器是Python的瑞士军刀,用好它代码质量提升一个档次!

更多推荐