2025/02/20

テクノロジー

APIのモックを用いたユニットテストとE2Eテストについて

この記事の目次

    今回はAPIのモックを用いたユニットテストとE2Eテストについて実際のコードを使いながら紹介しようと思います。

    モックを用いたユニットテストの概要

    モックとは

    まずモックとは何かについて説明します。

    テストしたい関数が他のクラスに依存していることはよくあると思います。
    例えば、SNSのとあるユーザーの投稿を取得するAPIのサービスクラスは投稿が公開か未公開か確認のために投稿のレポジトリクラスに依存し、またユーザーの存在を確認するためにユーザーのレポジトリクラスにも依存しています。
    この状況下において、モックを使わずにUTを実装すると、ユーザーのレポジトリクラスのUTが失敗した場合、サービスクラスのUTも失敗しているということになります。そのため、原因特定に時間がかかります。

    モックを使って実装すると、依存しているクラスや関数が想定通りの挙動をするように設定できるため、ユーザーのレポジトリクラスのUTが失敗した場合でも、サービスクラスのUTは成功します。そのため、瞬時にユーザーのレポジトリクラスのみでバグが生じていることがわかります。
    つまり、モックとはUTの責任範囲を明確にし、 UTを実装しやすくする存在です。

    依存注入

    ただ、注意しなくてはいけないのは UTの対象関数の内、UTが制御できるのは対象関数の呼び方のみであるということです。
    つまり、モックを使ってUTを制御するためには、モックするクラスを関数またはクラスの引数に設定する必要があります。

    そのため、対象関数またはクラスの引数はクラスを注入できるように実装する必要があります。
    これを依存注入と呼びます。

    モックを用いたUTの実例

    モックや依存注入について説明が終わったため、実際のコードを使って説明したいと思います。
    今回はSNSのとあるユーザーの投稿を取得するAPIとそのUTコードを実装しました。

    コントローラー

    コントローラーのコードは以下のようになっています。
    get_post_info関数の引数にサービスクラスを依存注入しています。

    /controllers/get_post_controller.py<code>from dependency_injector.wiring import Provide, inject
    from fastapi import APIRouter, Depends
    
    from app.api_schemas.get_post_schema import (GetPostRequest, GetPostResponse,
                                                 GetPostSchema)
    from app.core.container import Container
    from app.services.get_post_service import GetPostService
    
    router = APIRouter()
    
    
    @router.get("/posts/{post_id}")
    @inject
    def get_post_info(
        get_post_request: GetPostRequest = Depends(),
        service: GetPostService = Depends(Provide[Container.get_post_service]),
    ) -> GetPostResponse:
        if post := service.get_post_info(
            get_post_request.post_id, get_post_request.user_id
        ):
            return GetPostResponse(
                result=True,
                post=GetPostSchema(title=post.title, description=post.description),
            )
        return GetPostResponse(result=False, post=None)
    </code>

    これに対するUTコードは以下のようになっています。
    mock_get_post_service関数でサービスのモックを作成し、各テストケースで利用しています。

    コントローラーのget_post_info関数の中で使うサービスのメソッドの返り値をモックで設定することで関数内の条件分岐を制御しています。

    /tests/controllers/test_get_post_controller.pyfrom unittest.mock import MagicMock
    
    import pytest
    import requests
    
    from app.api_schemas.get_post_schema import (GetPostRequest, GetPostResponse,
                                                 GetPostSchema)
    from app.controllers.get_post_controller import get_post_info
    from app.models.post import PostTable
    from app.models.user import UserTable
    from tests.base_test import BaseTest
    
    @pytest.fixture()
    def mock_get_post_service():
        return MagicMock()
    
    def test_get_post_succeeds(mock_get_post_service):
        mock_get_post_service.get_post_info.return_value = (
            PostTable.test_public_post_by_user1_data()
        )
    
        request = GetPostRequest(post_id=1, user_id=1)
        response = get_post_info(get_post_request=request, service=mock_get_post_service)
        assert response == GetPostResponse(
            result=True,
            post=GetPostSchema(
                title=PostTable.test_public_post_by_user1_data().title,
                description=PostTable.test_public_post_by_user1_data().description,
            ),
        )
    
    def test_get_post_fails(mock_get_post_service):
        mock_get_post_service.get_post_info.return_value = None
    
        request = GetPostRequest(post_id=1, user_id=1)
        response = get_post_info(get_post_request=request, service=mock_get_post_service)
        assert response == GetPostResponse(result=False, post=None)

    サービス

    サービスのコードは以下のようになっています。

    サービスはクラスにまとめているため、クラスの__init__関数で依存するユーザーレポジトリクラスと投稿レポジトリクラスを注入しています。
    これによってクラス内の関数のインスタンスから依存先を利用できるようになっています。

    /services/get_post_service.py<code>from app.models.post import PostTable
    from app.repositories.post_repository import PostRepository
    from app.repositories.user_repository import UserRepository
    
    
    class GetPostService:
        def __init__(
            self, post_repository: PostRepository, user_repository: UserRepository
        ):
            self.post_repository = post_repository
            self.user_repository = user_repository
    
        def get_post_info(self, post_id, user_id) -> PostTable:
            if not self.user_repository.get_user(user_id):
                return None
            post = self.post_repository.get_post(post_id)
            if not post or not self.__is_visible(post, user_id):
                return None
            return post
    
        def __is_visible(self, post: PostTable, user_id) -> bool:
            if post.user_id == user_id:
                return True
            elif not post.is_private:
                return True
            else:
                return False
    </code>

    これに対するUTコードは以下のようになっています。

    get_post_service関数で依存するユーザーレポジトリクラスと投稿レポジトリクラスをモックしています。
    サービスのget_post_info関数の中で使うレポジトリのメソッドの返り値をモックで設定することで関数内の条件分岐を制御しています。

    tests/services/test_get_post_service.py<code>from datetime import datetime
    from unittest.mock import MagicMock
    
    import pytest
    
    from app.models.post import PostTable
    from app.models.user import UserTable
    from app.services.get_post_service import GetPostService
    
    
    @pytest.fixture()
    def get_post_service():
        return GetPostService(post_repository=MagicMock(), user_repository=MagicMock())
    
    
    def test_non_existing_user(get_post_service):
        non_existing_user_id = 1
        get_post_service.user_repository.get_user.return_value = None
        assert get_post_service.get_post_info(1, non_existing_user_id) == None
    
    
    def test_non_existing_post(get_post_service):
        non_existing_post_id = 1
        get_post_service.user_repository.get_user.return_value = (
            UserTable.test_not_login_user1_data()
        )
        get_post_service.post_repository.get_post.return_value = None
        assert (
            get_post_service.get_post_info(
                non_existing_post_id, UserTable.test_not_login_user1_data().id
            )
            == None
        )
    
    
    def test_get_private_post_from_non_author(get_post_service):
        get_post_service.user_repository.get_user.return_value = (
            UserTable.test_not_login_user1_data()
        )
        get_post_service.post_repository.get_post.return_value = (
            PostTable.test_private_post_by_user1_data()
        )
        assert (
            get_post_service.get_post_info(
                UserTable.test_not_login_user1_data().id,
                PostTable.test_private_post_by_user1_data().id,
            )
            == None
        )
    
    
    def test_get_private_post_from_author(get_post_service):
        get_post_service.user_repository.get_user.return_value = (
            UserTable.test_not_login_user1_data()
        )
        get_post_service.post_repository.get_post.return_value = (
            PostTable.test_private_post_by_user1_data()
        )
        assert (
            get_post_service.get_post_info(
                PostTable.test_private_post_by_user1_data().id,
                UserTable.test_not_login_user1_data().id,
            )
            == get_post_service.post_repository.get_post.return_value
        )
    
    
    def test_get_public_post_from_non_author(get_post_service):
        get_post_service.user_repository.get_user.return_value = (
            UserTable.test_login_user2_data()
        )
        get_post_service.post_repository.get_post.return_value = (
            PostTable.test_public_post_by_user1_data()
        )
        assert (
            get_post_service.get_post_info(
                PostTable.test_public_post_by_user1_data().id,
                UserTable.test_login_user2_data().id,
            )
            == get_post_service.post_repository.get_post.return_value
        )
    </code>

    レポジトリ

    レポジトリのコードは以下のようになっています。

    レポジトリはクラスにまとめているため、クラスの__init__関数で依存するDBを注入しています。
    DBを注入することで開発環境のDBとは別のDBにデータを入れることができるため、開発環境のDBに影響を与えずに済みます。
    これによってクラス内の関数のインスタンスから依存先を利用できるようになっています。

    /repositories/post_repository.py<code>from pydantic import BaseModel
    from sqlalchemy.orm import Session
    
    from app.models.post import PostTable
    from app.models.user import UserTable
    
    
    class PostRepository:
        def __init__(self, db: Session):
            self.db = db
    
        def get_post(self, post_id) -> PostTable:
            return (
                self.db.query(PostTable)
                .join(UserTable, UserTable.id == PostTable.user_id)
                .filter(PostTable.id == post_id)
                .first()
            )
    </code>

    これに対するUTコードは以下のようになっています。

    レポジトリはAPIの最奥層であるため、何もモックせずに実際にテスト用DBにデータを入れた上でUTを書いています。
    テスト用DBとアプリ用DBの切り替えはbase_test.pyで行っていますが、ここでは省略します。

    /tests/repositories/test_post_repository.py<code>from app.helpers.helper import get_datetime_now_db_format
    from app.models.post import PostTable
    from app.models.user import UserTable
    from app.repositories.post_repository import PostRepository
    from tests.base_test import BaseTest
    
    
    class TestPostRepository(BaseTest):
        @classmethod
        def _initialize_repository(cls):
            cls.post_repository = PostRepository(cls.db)
    
        @classmethod
        def _insert_data(cls):
            cls.db.add_all(
                [
                    PostTable.test_public_post_by_user1_data(),
                    UserTable.test_not_login_user1_data(),
                ]
            )
            cls.db.commit()
    
        @classmethod
        def test_get_existing_post(cls):
            response = cls.post_repository.get_post(
                PostTable.test_public_post_by_user1_data().id
            )
            assert response.id == PostTable.test_public_post_by_user1_data().id
            assert response.title == PostTable.test_public_post_by_user1_data().title
            assert (
                response.description
                == PostTable.test_public_post_by_user1_data().description
            )
            assert response.user_id == PostTable.test_public_post_by_user1_data().user_id
            assert (
                response.is_private == PostTable.test_public_post_by_user1_data().is_private
            )
    
        @classmethod
        def test_get_non_existing_post(cls):
            non_existing_post_id = 2
            response = cls.post_repository.get_post(non_existing_post_id)
            assert response == None
    </code>

    E2Eテスト

    E2Eテストとはシステム全体をテストするものです。
    E2Eテストを実行することで関数間の値の受け渡しが正常であることを担保し、UTのみではカバーできないところをカバーし、バグが発生する可能性を下げることができます。

    このレポジトリはAPIしか作成していないため、フロントエンドの挙動までは確認しません。
    ここでは特定のパスにリクエストが来てからレスポンスが返されるまでの一連の動作を確認します。
    そのため、ここでも実際にデータを入れます。

    /tests/controllers/test_get_post_controller.py<code>from unittest.mock import MagicMock
    
    import pytest
    import requests
    
    from app.api_schemas.get_post_schema import (GetPostRequest, GetPostResponse,
                                                 GetPostSchema)
    from app.controllers.get_post_controller import get_post_info
    from app.models.post import PostTable
    from app.models.user import UserTable
    from tests.base_test import BaseTest
    
    class TestGetPostController(BaseTest):
        @classmethod
        def _insert_data(cls):
            cls.db.add_all(
                [
                    UserTable.test_not_login_user1_data(),
                    PostTable.test_public_post_by_user1_data(),
                ]
            )
            cls.db.commit()
    
        @classmethod
        def test_e2e(cls):
            post_id = str(PostTable.test_public_post_by_user1_data().id)
            user_id = str(UserTable.test_not_login_user1_data().id)
            response = cls.client.get(
                "/posts/" + post_id, params={"post_id": post_id, "user_id": user_id}
            )
            assert response.status_code == 200
            assert response.json() == {
                "result": True,
                "post": {
                    "title": PostTable.test_public_post_by_user1_data().title,
                    "description": PostTable.test_public_post_by_user1_data().description,
                },
            }
    </code>

    最後に

    APIのモックを用いたユニットテストは、自分が書いたコードが仕様を正しく反映していることを迅速に確認するための非常に効率的な手法です。
    一定期間が経過して仕様を忘れてしまった場合でも、仕様がコードとして明確に表現されているため、再確認が容易になります。

    また、E2Eテストはユニットテストだけでは見落としがちな、全体の動作を確認するのに非常に有効です。
    これにより、実際の動作環境での問題を早期に発見し、修正することができます。

    ぜひ、これらのテスト手法を取り入れて、効率的にAPI開発を進めていただければと思います。
    今後も、テストの重要性を意識しながら、より良いソフトウェアを作り上げていきましょう。

    ※本記事は2025年02月時点の情報です。

    著者:マイナビエンジニアブログ編集部