"""Tests for ``populate_database``.

The tests below exercise ``populate_database`` with a mapping of wave number
to ``pandas.DataFrame`` and then inspect the resulting SQLite file directly.
"""

import contextlib
import os
import sqlite3
import tempfile
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from src.utils.database_populator import populate_database


@contextlib.contextmanager
def _temporary_db_path():
    """Yield the path of a fresh temporary ``.sqlite`` file; delete it on exit.

    The file handle is closed before the path is handed out so that SQLite
    (and ``os.unlink`` on Windows) never competes with an open handle.
    """
    with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
        db_path = tmp_file.name
    try:
        yield db_path
    finally:
        os.unlink(db_path)


def _connected(db_path):
    """Return a context manager yielding a connection that is always closed.

    ``sqlite3.Connection`` used directly in a ``with`` statement only manages
    the transaction, it does not close the connection, hence ``closing``.
    Closing even when an assertion fails keeps the temp file deletable.
    """
    return contextlib.closing(sqlite3.connect(db_path))


def _table_names(conn):
    """Return the names of all tables in the database, sorted by name."""
    rows = conn.execute(
        "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
    ).fetchall()
    return [name for (name,) in rows]


def _column_names(conn, table):
    """Return the column names of ``table`` in declaration order.

    ``PRAGMA`` statements cannot take bound parameters, so the table name is
    interpolated; callers only pass literals defined in this module.
    """
    # Field 1 of each ``table_info`` row is the column name.
    return [row[1] for row in conn.execute(f"PRAGMA table_info({table})")]


class TestPopulateDatabase:
    """Behavioural tests for ``populate_database``."""

    @staticmethod
    def single_wave_data_creates_correct_table():
        """A single wave ends up in a ``wave1`` table holding the frame's rows."""
        test_data = {1: pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]})}

        with _temporary_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                assert "wave1" in _table_names(conn)

                rows = conn.execute("SELECT * FROM wave1").fetchall()
                assert len(rows) == 2
                assert rows[0] == (1, "a")
                assert rows[1] == (2, "b")

    @staticmethod
    def multiple_waves_create_separate_tables():
        """Each wave gets its own table and no extra tables are created."""
        test_data = {
            1: pd.DataFrame({"wave1_col": [1, 2]}),
            2: pd.DataFrame({"wave2_col": [3, 4]}),
            3: pd.DataFrame({"wave3_col": [5, 6]}),
        }

        with _temporary_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                tables = _table_names(conn)
                assert "wave1" in tables
                assert "wave2" in tables
                assert "wave3" in tables
                assert len(tables) == 3

    @staticmethod
    def empty_dataframe_creates_table_with_no_rows():
        """An empty frame still creates its table and column, with zero rows."""
        test_data = {1: pd.DataFrame({"empty_col": []})}

        with _temporary_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                row_count = conn.execute("SELECT COUNT(*) FROM wave1").fetchone()[0]
                assert row_count == 0

                columns = _column_names(conn, "wave1")
                assert len(columns) == 1
                assert columns[0] == "empty_col"

    @staticmethod
    def empty_dictionary_creates_no_tables():
        """An empty mapping leaves the database without any tables."""
        test_data = {}

        with _temporary_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                assert len(_table_names(conn)) == 0

    @staticmethod
    def existing_database_tables_are_replaced():
        """A pre-existing ``wave1`` table is replaced, schema and rows alike."""
        test_data = {1: pd.DataFrame({"col": [1, 2]})}

        with _temporary_db_path() as db_path:
            # Seed a conflicting table that must not survive the call.
            with _connected(db_path) as conn:
                conn.execute("CREATE TABLE wave1 (old_col INTEGER)")
                conn.execute("INSERT INTO wave1 VALUES (999)")
                conn.commit()

            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                rows = conn.execute("SELECT * FROM wave1").fetchall()
                assert len(rows) == 2
                assert rows[0] == (1,)
                assert rows[1] == (2,)

                columns = _column_names(conn, "wave1")
                assert len(columns) == 1
                assert columns[0] == "col"

    @staticmethod
    def database_uses_default_path_when_not_specified():
        """Without an explicit path the default results database is opened."""
        test_data = {1: pd.DataFrame({"col": [1]})}
        default_path = "results/study_results.sqlite"

        # Patch the connection factory so no real file is created on disk.
        with patch("sqlite3.connect") as mock_connect:
            mock_connection = MagicMock()
            mock_connect.return_value = mock_connection

            populate_database(test_data)

            mock_connect.assert_called_once_with(default_path)
            mock_connection.close.assert_called_once()

    @staticmethod
    def dataframe_with_various_data_types_preserved():
        """Columns of mixed dtypes round-trip with names and order intact."""
        test_data = {
            1: pd.DataFrame(
                {
                    "int_col": [1, 2],
                    "float_col": [1.5, 2.7],
                    "str_col": ["text1", "text2"],
                    "bool_col": [True, False],
                }
            )
        }

        with _temporary_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                df_result = pd.read_sql_query("SELECT * FROM wave1", conn)

            assert len(df_result) == 2
            assert list(df_result.columns) == [
                "int_col",
                "float_col",
                "str_col",
                "bool_col",
            ]
            assert df_result["int_col"].iloc[0] == 1
            assert df_result["str_col"].iloc[1] == "text2"

    @patch("sqlite3.connect")
    def connection_closed_even_when_exception_occurs(self, mock_connect):
        """The connection is closed exactly once even if writing a frame fails."""
        mock_connection = MagicMock()
        mock_connect.return_value = mock_connection
        # Allow the mock to be used in a ``with`` statement; ``__exit__``
        # returning False means an in-flight exception is not suppressed.
        mock_connection.__enter__ = MagicMock(return_value=mock_connection)
        mock_connection.__exit__ = MagicMock(return_value=False)

        test_dataframe = pd.DataFrame({"col": [1, 2]})
        # Force the write step to fail regardless of the connection used.
        test_dataframe.to_sql = MagicMock(side_effect=Exception("SQL Error"))
        test_data = {1: test_dataframe}

        with pytest.raises(Exception, match="SQL Error"):
            populate_database(test_data, "test.db")

        mock_connection.close.assert_called_once()

    @staticmethod
    def wave_numbers_create_correct_table_names():
        """Table names are ``wave<number>`` for arbitrary, unordered wave keys."""
        test_data = {
            10: pd.DataFrame({"col": [1]}),
            99: pd.DataFrame({"col": [2]}),
            1: pd.DataFrame({"col": [3]}),
        }

        with _temporary_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                tables = _table_names(conn)

            # Names come back sorted as text, so "wave10" precedes "wave99".
            expected_tables = ["wave1", "wave10", "wave99"]
            assert tables == expected_tables

    @staticmethod
    def dataframe_index_not_stored_in_database():
        """A custom DataFrame index is not written as an extra column."""
        df_with_custom_index = pd.DataFrame({"col": [1, 2]})
        df_with_custom_index.index = ["row1", "row2"]
        test_data = {1: df_with_custom_index}

        with _temporary_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connected(db_path) as conn:
                columns = _column_names(conn, "wave1")

            assert "col" in columns
            assert "index" not in columns
            assert len(columns) == 1