Python 单元测试:完整指南与示例

使用 pytest 进行 Python 测试,TDD,模拟和覆盖率

目录

单元测试确保你的 Python 代码正常工作,并且随着项目的演进继续正常运行。
本全面指南涵盖了关于 Python 单元测试 的所有内容,从基本概念到高级技术。

Python 单元测试

为什么单元测试很重要

单元测试为 Python 开发者提供了诸多好处:

  • 早期发现错误:在代码进入生产环境之前发现错误
  • 代码质量:迫使你编写模块化、可测试的代码
  • 重构信心:在知道测试会捕捉回归的情况下安全地进行更改
  • 文档:测试作为代码应如何工作的可执行文档
  • 更快的开发:自动化测试比手动测试更快
  • 更好的设计:编写可测试的代码会导致更好的架构

理解单元测试基础

什么是单元测试?

单元测试验证应用程序中最小的可测试部分(通常是函数或方法)的隔离性。它应该:

  • 快速:在毫秒内运行
  • 隔离:独立于其他测试和外部系统
  • 可重复:每次运行结果相同
  • 自我验证:无需人工检查即可明确通过或失败
  • 及时:在代码编写之前或同时编写

测试金字塔

一个健康的测试套件遵循测试金字塔:

           /\
          /  \     E2E 测试(少量)
         /____\
        /      \   集成测试(一些)
       /_____ ___\
      /          \ 单元测试(很多)
     /_____ __ ___\

单元测试是基础——它们数量众多、快速,并能提供快速反馈。

Python 测试框架比较

unittest:内置框架

Python 的标准库包含 unittest,灵感来自 JUnit:

import unittest

class TestCalculator(unittest.TestCase):
    def setUp(self):
        """每个测试前运行"""
        self.calc = Calculator()
    
    def tearDown(self):
        """每个测试后运行"""
        self.calc = None
    
    def test_addition(self):
        result = self.calc.add(2, 3)
        self.assertEqual(result, 5)
    
    def test_division_by_zero(self):
        with self.assertRaises(ZeroDivisionError):
            self.calc.divide(10, 0)

if __name__ == '__main__':
    unittest.main()

优点:

  • 内置,无需安装
  • 对熟悉 xUnit 框架的开发者友好
  • 企业友好,已建立

缺点:

  • 语法冗长,有样板代码
  • 需要类来组织测试
  • 固件管理不够灵活

pytest:现代选择

pytest 是最受欢迎的第三方测试框架:

import pytest

def test_addition():
    calc = Calculator()
    assert calc.add(2, 3) == 5

def test_division_by_zero():
    calc = Calculator()
    with pytest.raises(ZeroDivisionError):
        calc.divide(10, 0)

@pytest.mark.parametrize("a,b,expected", [
    (2, 3, 5),
    (-1, 1, 0),
    (0, 0, 0),
])
def test_addition_parametrized(a, b, expected):
    calc = Calculator()
    assert calc.add(a, b) == expected

优点:

  • 简单、Pythonic 的语法
  • 强大的固件系统
  • 优秀的插件生态系统
  • 更好的错误报告
  • 内置参数化测试

缺点:

  • 需要安装
  • 对其他语言开发者来说不够熟悉

安装:

pip install pytest pytest-cov pytest-mock

编写你的第一个单元测试

让我们使用测试驱动开发(TDD)从头开始构建一个简单的示例。如果你是 Python 新手,或者需要语法和语言功能的快速参考,请查看我们的 Python 快速参考,以全面了解 Python 基础知识。

示例:字符串实用函数

第一步:先写测试(红色)

创建 test_string_utils.py

import pytest
from string_utils import reverse_string, is_palindrome, count_vowels

def test_reverse_string():
    assert reverse_string("hello") == "olleh"
    assert reverse_string("") == ""
    assert reverse_string("a") == "a"

def test_is_palindrome():
    assert is_palindrome("racecar") == True
    assert is_palindrome("hello") == False
    assert is_palindrome("") == True
    assert is_palindrome("A man a plan a canal Panama") == True

def test_count_vowels():
    assert count_vowels("hello") == 2
    assert count_vowels("HELLO") == 2
    assert count_vowels("xyz") == 0
    assert count_vowels("") == 0

第二步:编写最小代码以通过(绿色)

创建 string_utils.py

def reverse_string(s: str) -> str:
    """反转字符串"""
    return s[::-1]

def is_palindrome(s: str) -> bool:
    """检查字符串是否是回文(忽略大小写和空格)"""
    cleaned = ''.join(s.lower().split())
    return cleaned == cleaned[::-1]

def count_vowels(s: str) -> int:
    """计算字符串中的元音字母数量"""
    return sum(1 for char in s.lower() if char in 'aeiou')

第三步:运行测试

pytest test_string_utils.py -v

输出:

test_string_utils.py::test_reverse_string PASSED
test_string_utils.py::test_is_palindrome PASSED
test_string_utils.py::test_count_vowels PASSED

高级测试技术

使用固件进行测试设置

固件提供可重用的测试设置和清理:

import pytest
from database import Database

@pytest.fixture
def db():
    """创建测试数据库"""
    database = Database(":memory:")
    database.create_tables()
    yield database  # 提供给测试
    database.close()  # 测试后清理

@pytest.fixture
def sample_users(db):
    """向数据库添加示例用户"""
    db.add_user("Alice", "alice@example.com")
    db.add_user("Bob", "bob@example.com")
    return db

def test_get_user(sample_users):
    user = sample_users.get_user_by_email("alice@example.com")
    assert user.name == "Alice"

def test_user_count(sample_users):
    assert sample_users.count_users() == 2

固件作用域

使用作用域控制固件生命周期:

@pytest.fixture(scope="function")  # 默认:每个测试运行一次
def func_fixture():
    return create_resource()

@pytest.fixture(scope="class")  # 每个测试类运行一次
def class_fixture():
    return create_resource()

@pytest.fixture(scope="module")  # 每个模块运行一次
def module_fixture():
    return create_resource()

@pytest.fixture(scope="session")  # 每个测试会话运行一次
def session_fixture():
    return create_expensive_resource()

模拟外部依赖

使用模拟来隔离代码与外部依赖:

from unittest.mock import Mock, patch, MagicMock
import requests

class WeatherService:
    def get_temperature(self, city):
        response = requests.get(f"https://api.weather.com/{city}")
        return response.json()["temp"]

# 使用模拟测试
def test_get_temperature():
    service = WeatherService()
    
    # 模拟 requests.get 函数
    with patch('requests.get') as mock_get:
        # 配置模拟响应
        mock_response = Mock()
        mock_response.json.return_value = {"temp": 72}
        mock_get.return_value = mock_response
        
        # 测试
        temp = service.get_temperature("Boston")
        assert temp == 72
        
        # 验证调用
        mock_get.assert_called_once_with("https://api.weather.com/Boston")

使用 pytest-mock 插件

pytest-mock 提供更简洁的语法:

def test_get_temperature(mocker):
    service = WeatherService()
    
    # 使用 pytest-mock 模拟
    mock_response = mocker.Mock()
    mock_response.json.return_value = {"temp": 72}
    mocker.patch('requests.get', return_value=mock_response)
    
    temp = service.get_temperature("Boston")
    assert temp == 72

参数化测试

高效测试多种场景:

import pytest

@pytest.mark.parametrize("input,expected", [
    ("", True),
    ("a", True),
    ("ab", False),
    ("aba", True),
    ("racecar", True),
    ("hello", False),
])
def test_is_palindrome_parametrized(input, expected):
    assert is_palindrome(input) == expected

@pytest.mark.parametrize("number,is_even", [
    (0, True),
    (1, False),
    (2, True),
    (-1, False),
    (-2, True),
])
def test_is_even(number, is_even):
    assert (number % 2 == 0) == is_even

测试异常和错误处理

import pytest

def divide(a, b):
    if b == 0:
        raise ValueError("不能除以零")
    return a / b

def test_divide_by_zero():
    with pytest.raises(ValueError, match="不能除以零"):
        divide(10, 0)

def test_divide_by_zero_with_message():
    with pytest.raises(ValueError) as exc_info:
        divide(10, 0)
    assert "零" in str(exc_info.value).lower()

# 测试没有异常被抛出
def test_divide_success():
    result = divide(10, 2)
    assert result == 5.0

测试异步代码

测试异步代码对于现代 Python 应用程序至关重要,尤其是在处理 API、数据库或 AI 服务时。以下是测试异步函数的方法:

import pytest
import asyncio

async def fetch_data(url):
    """异步函数获取数据"""
    await asyncio.sleep(0.1)  # 模拟 API 调用
    return {"status": "success"}

@pytest.mark.asyncio
async def test_fetch_data():
    result = await fetch_data("https://api.example.com")
    assert result["status"] == "success"

@pytest.mark.asyncio
async def test_fetch_data_with_mock(mocker):
    # 模拟异步函数
    mock_fetch = mocker.AsyncMock(return_value={"status": "mocked"})
    mocker.patch('module.fetch_data', mock_fetch)
    
    result = await fetch_data("https://api.example.com")
    assert result["status"] == "mocked"

对于测试异步代码与 AI 服务的实际示例,请参阅我们的指南 将 Ollama 与 Python 集成,其中包含 LLM 交互的测试策略。

代码覆盖率

衡量你的代码中有多少被测试:

使用 pytest-cov

# 运行测试并获取覆盖率
pytest --cov=myproject tests/

# 生成 HTML 报告
pytest --cov=myproject --cov-report=html tests/

# 显示缺失的行
pytest --cov=myproject --cov-report=term-missing tests/

覆盖率配置

创建 .coveragerc

[run]
source = myproject
omit =
    */tests/*
    */venv/*
    */__pycache__/*

[report]
exclude_lines =
    pragma: no cover
    def __repr__
    raise AssertionError
    raise NotImplementedError
    if __name__ == .__main__.:
    if TYPE_CHECKING:
    @abstractmethod

覆盖率最佳实践

  1. 目标是 80%+ 的覆盖率 对于关键代码路径
  2. 不要执着于 100% —— 专注于有意义的测试
  3. 测试边界情况 而不仅仅是正常路径
  4. 从覆盖率报告中排除样板代码
  5. 将覆盖率作为指南而非目标

测试组织和项目结构

推荐结构

myproject/
├── myproject/
│   ├── __init__.py
│   ├── module1.py
│   ├── module2.py
│   └── utils.py
├── tests/
│   ├── __init__.py
│   ├── conftest.py          # 共享固件
│   ├── test_module1.py
│   ├── test_module2.py
│   ├── test_utils.py
│   └── integration/
│       ├── __init__.py
│       └── test_integration.py
├── pytest.ini
├── requirements.txt
└── requirements-dev.txt

conftest.py 用于共享固件

# tests/conftest.py
import pytest
from myproject.database import Database

@pytest.fixture(scope="session")
def test_db():
    """会话级测试数据库"""
    db = Database(":memory:")
    db.create_schema()
    yield db
    db.close()

@pytest.fixture
def clean_db(test_db):
    """每个测试的干净数据库"""
    test_db.clear_all_tables()
    return test_db

pytest.ini 配置

[pytest]
testpaths = tests
python_files = test_*.py
python_classes = Test*
python_functions = test_*
addopts =
    -v
    --strict-markers
    --cov=myproject
    --cov-report=term-missing
    --cov-report=html
markers =
    slow: 标记为慢速的测试
    integration: 标记为集成测试
    unit: 标记为单元测试

单元测试最佳实践

1. 遵循 AAA 模式

Arrange-Act-Assert 使测试清晰:

def test_user_creation():
    # Arrange
    username = "john_doe"
    email = "john@example.com"
    
    # Act
    user = User(username, email)
    
    # Assert
    assert user.username == username
    assert user.email == email
    assert user.is_active == True

2. 每个测试一个断言(指导原则,而非规则)

# 好:聚焦的测试
def test_user_username():
    user = User("john_doe", "john@example.com")
    assert user.username == "john_doe"

def test_user_email():
    user = User("john_doe", "john@example.com")
    assert user.email == "john@example.com"

# 也接受:相关断言
def test_user_creation():
    user = User("john_doe", "john@example.com")
    assert user.username == "john_doe"
    assert user.email == "john@example.com"
    assert isinstance(user.created_at, datetime)

3. 使用描述性的测试名称

# 差
def test_user():
    pass

# 好
def test_user_creation_with_valid_data():
    pass

def test_user_creation_fails_with_invalid_email():
    pass

def test_user_password_is_hashed_after_setting():
    pass

4. 测试边界情况和边缘情况

def test_age_validation():
    # 有效情况
    assert validate_age(0) == True
    assert validate_age(18) == True
    assert validate_age(120) == True
    
    # 边界情况
    assert validate_age(-1) == False
    assert validate_age(121) == False
    
    # 边缘情况
    with pytest.raises(TypeError):
        validate_age("18")
    with pytest.raises(TypeError):
        validate_age(None)

5. 保持测试独立

# 差:测试依赖于顺序
counter = 0

def test_increment():
    global counter
    counter += 1
    assert counter == 1

def test_increment_again():  # 如果单独运行会失败
    global counter
    counter += 1
    assert counter == 2

# 好:测试独立
def test_increment():
    counter = Counter()
    counter.increment()
    assert counter.value == 1

def test_increment_multiple_times():
    counter = Counter()
    counter.increment()
    counter.increment()
    assert counter.value == 2

6. 不要测试实现细节

# 差:测试实现
def test_sort_uses_quicksort():
    sorter = Sorter()
    assert sorter.algorithm == "quicksort"

# 好:测试行为
def test_sort_returns_sorted_list():
    sorter = Sorter()
    result = sorter.sort([3, 1, 2])
    assert result == [1, 2, 3]

7. 测试真实世界使用情况

当测试处理或转换数据的库时,专注于真实世界的场景。例如,如果你正在处理网页抓取或内容转换,请查看我们的指南 使用 Python 将 HTML 转换为 Markdown,其中包含不同转换库的测试策略和基准测试比较。

持续集成集成

GitHub Actions 示例

# .github/workflows/tests.yml
name: Tests

on: [push, pull_request]

jobs:
  test:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: [3.9, 3.10, 3.11, 3.12]
    
    steps:
    - uses: actions/checkout@v3
    
    - name: 设置 Python ${{ matrix.python-version }}
      uses: actions/setup-python@v4
      with:
        python-version: ${{ matrix.python-version }}
    
    - name: 安装依赖项
      run: |
        python -m pip install --upgrade pip
        pip install -r requirements.txt
        pip install -r requirements-dev.txt        
    
    - name: 运行测试
      run: |
        pytest --cov=myproject --cov-report=xml        
    
    - name: 上传覆盖率
      uses: codecov/codecov-action@v3
      with:
        file: ./coverage.xml

测试无服务器函数

当测试 AWS Lambda 函数或无服务器应用程序时,请考虑与单元测试结合的集成测试策略。我们的指南 使用 Python 和 Terraform 在 AWS 上构建双模式 AWS Lambda 覆盖了无服务器 Python 应用程序的测试方法,包括如何测试 Lambda 处理程序、SQS 消费者和 API Gateway 集成。

测试最佳实践检查表

  • 在编写代码之前或同时编写测试(TDD)
  • 保持测试快速(< 1 秒/测试)
  • 使测试独立和隔离
  • 使用描述性的测试名称
  • 遵循 AAA 模式(Arrange-Act-Assert)
  • 测试边界情况和错误条件
  • 模拟外部依赖
  • 目标是 80%+ 的代码覆盖率
  • 在 CI/CD 管道中运行测试
  • 定期审查和重构测试
  • 文档化复杂的测试场景
  • 使用固件进行常见设置
  • 参数化相似的测试
  • 保持测试简单和可读

常见测试模式

测试类

class TestUser:
    @pytest.fixture
    def user(self):
        return User("john_doe", "john@example.com")
    
    def test_username(self, user):
        assert user.username == "john_doe"
    
    def test_email(self, user):
        assert user.email == "john@example.com"
    
    def test_full_name(self, user):
        user.first_name = "John"
        user.last_name = "Doe"
        assert user.full_name() == "John Doe"

使用临时文件进行测试

import pytest
from pathlib import Path

@pytest.fixture
def temp_file(tmp_path):
    """创建临时文件"""
    file_path = tmp_path / "test_file.txt"
    file_path.write_text("test content")
    return file_path

def test_read_file(temp_file):
    content = temp_file.read_text()
    assert content == "test content"

测试文件生成

当测试生成文件的代码(如 PDF、图像或文档)时,使用临时目录并验证文件属性:

@pytest.fixture
def temp_output_dir(tmp_path):
    """提供临时输出目录"""
    output_dir = tmp_path / "output"
    output_dir.mkdir()
    return output_dir

def test_pdf_generation(temp_output_dir):
    pdf_path = temp_output_dir / "output.pdf"
    generate_pdf(pdf_path, content="Test")
    
    assert pdf_path.exists()
    assert pdf_path.stat().st_size > 0

对于生成 PDF 的全面示例,请参阅我们的指南 在 Python 中生成 PDF,其中涵盖了各种 PDF 库的测试策略。

使用 Monkeypatch 进行测试

def test_environment_variable(monkeypatch):
    monkeypatch.setenv("API_KEY", "test_key_123")
    assert os.getenv("API_KEY") == "test_key_123"

def test_module_attribute(monkeypatch):
    monkeypatch.setattr("module.CONSTANT", 42)
    assert module.CONSTANT == 42

有用的链接和资源

相关资源

Python 基础和最佳实践

测试特定的 Python 用例

无服务器和云测试


单元测试是 Python 开发者的一项重要技能。无论你选择 unittest 还是 pytest,关键是编写一致的测试,保持它们的可维护性,并将它们集成到你的开发流程中。从简单的测试开始,逐步采用高级技术如模拟和固件,并使用覆盖率工具来识别未测试的代码。