preprocessing/tests/test_database_populator.py
2025-12-15 13:47:28 +01:00

267 lines
8.1 KiB
Python

import pytest
import sqlite3
import pandas as pd
import tempfile
import os
from unittest.mock import patch, MagicMock
from src.utils.database_populator import populate_database
class TestPopulateDatabase:
    """Tests for ``populate_database``.

    Most tests write to a throwaway SQLite file and then inspect that file
    through a second connection. The inspecting connection is always closed in
    a ``finally`` clause, so a failing assertion neither leaks the handle nor
    prevents the temporary file from being deleted (an open handle makes
    ``os.unlink`` fail on Windows and would mask the real assertion error).

    NOTE(review): the method names do not start with ``test``; under pytest's
    default ``python_functions`` setting they would not be collected. Confirm
    that the project configures collection accordingly.
    """

    @staticmethod
    def single_wave_data_creates_correct_table():
        """A single wave ends up in a ``wave1`` table with both rows intact."""
        test_data = {1: pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]})}
        # delete=False: only the path is needed; the file is removed manually.
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            populate_database(test_data, db_path)
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = cursor.fetchall()
                assert ("wave1",) in tables
                cursor.execute("SELECT * FROM wave1")
                rows = cursor.fetchall()
                assert len(rows) == 2
                assert rows[0] == (1, "a")
                assert rows[1] == (2, "b")
            finally:
                conn.close()
        finally:
            os.unlink(db_path)

    @staticmethod
    def multiple_waves_create_separate_tables():
        """Three waves produce exactly the tables ``wave1``..``wave3``."""
        test_data = {
            1: pd.DataFrame({"wave1_col": [1, 2]}),
            2: pd.DataFrame({"wave2_col": [3, 4]}),
            3: pd.DataFrame({"wave3_col": [5, 6]}),
        }
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            populate_database(test_data, db_path)
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = [table[0] for table in cursor.fetchall()]
                assert "wave1" in tables
                assert "wave2" in tables
                assert "wave3" in tables
                assert len(tables) == 3
            finally:
                conn.close()
        finally:
            os.unlink(db_path)

    @staticmethod
    def empty_dataframe_creates_table_with_no_rows():
        """An empty frame still creates its table and column, with zero rows."""
        test_data = {1: pd.DataFrame({"empty_col": []})}
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            populate_database(test_data, db_path)
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT COUNT(*) FROM wave1")
                row_count = cursor.fetchone()[0]
                assert row_count == 0
                cursor.execute("PRAGMA table_info(wave1)")
                columns = cursor.fetchall()
                assert len(columns) == 1
                # Field 1 of a table_info row is the column name.
                assert columns[0][1] == "empty_col"
            finally:
                conn.close()
        finally:
            os.unlink(db_path)

    @staticmethod
    def empty_dictionary_creates_no_tables():
        """An empty input mapping leaves the database without any tables."""
        test_data = {}
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            populate_database(test_data, db_path)
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = cursor.fetchall()
                assert len(tables) == 0
            finally:
                conn.close()
        finally:
            os.unlink(db_path)

    @staticmethod
    def existing_database_tables_are_replaced():
        """A pre-existing ``wave1`` table is replaced, not appended to."""
        test_data = {1: pd.DataFrame({"col": [1, 2]})}
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            # Seed the file with a conflicting table (different schema and data).
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("CREATE TABLE wave1 (old_col INTEGER)")
                cursor.execute("INSERT INTO wave1 VALUES (999)")
                conn.commit()
            finally:
                conn.close()

            populate_database(test_data, db_path)

            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM wave1")
                rows = cursor.fetchall()
                assert len(rows) == 2
                assert rows[0] == (1,)
                assert rows[1] == (2,)
                cursor.execute("PRAGMA table_info(wave1)")
                columns = cursor.fetchall()
                assert len(columns) == 1
                assert columns[0][1] == "col"
            finally:
                conn.close()
        finally:
            os.unlink(db_path)

    @staticmethod
    def database_uses_default_path_when_not_specified():
        """Omitting the path connects to ``results/study_results.sqlite``."""
        test_data = {1: pd.DataFrame({"col": [1]})}
        default_path = "results/study_results.sqlite"
        # sqlite3.connect is patched so no file is created at the default path.
        with patch("sqlite3.connect") as mock_connect:
            mock_connection = MagicMock()
            mock_connect.return_value = mock_connection
            populate_database(test_data)
            mock_connect.assert_called_once_with(default_path)
            mock_connection.close.assert_called_once()

    @staticmethod
    def dataframe_with_various_data_types_preserved():
        """Mixed-dtype columns round-trip with their order and sample values."""
        test_data = {
            1: pd.DataFrame(
                {
                    "int_col": [1, 2],
                    "float_col": [1.5, 2.7],
                    "str_col": ["text1", "text2"],
                    "bool_col": [True, False],
                }
            )
        }
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            populate_database(test_data, db_path)
            conn = sqlite3.connect(db_path)
            try:
                df_result = pd.read_sql_query("SELECT * FROM wave1", conn)
                assert len(df_result) == 2
                assert list(df_result.columns) == [
                    "int_col",
                    "float_col",
                    "str_col",
                    "bool_col",
                ]
                assert df_result["int_col"].iloc[0] == 1
                assert df_result["str_col"].iloc[1] == "text2"
            finally:
                conn.close()
        finally:
            os.unlink(db_path)

    @patch("sqlite3.connect")
    def connection_closed_even_when_exception_occurs(self, mock_connect):
        """The connection is closed once even if writing a frame raises."""
        mock_connection = MagicMock()
        mock_connect.return_value = mock_connection
        # Let the mock also work if the code under test uses it in a ``with``.
        mock_connection.__enter__ = MagicMock(return_value=mock_connection)
        mock_connection.__exit__ = MagicMock(return_value=False)
        test_dataframe = pd.DataFrame({"col": [1, 2]})
        # Force the write step to fail on this particular frame.
        test_dataframe.to_sql = MagicMock(side_effect=Exception("SQL Error"))
        test_data = {1: test_dataframe}
        with pytest.raises(Exception, match="SQL Error"):
            populate_database(test_data, "test.db")
        mock_connection.close.assert_called_once()

    @staticmethod
    def wave_numbers_create_correct_table_names():
        """Arbitrary wave numbers map to ``wave<number>`` table names."""
        test_data = {
            10: pd.DataFrame({"col": [1]}),
            99: pd.DataFrame({"col": [2]}),
            1: pd.DataFrame({"col": [3]}),
        }
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            populate_database(test_data, db_path)
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
                )
                tables = [table[0] for table in cursor.fetchall()]
                # ORDER BY name sorts as text, hence wave10 before wave99.
                expected_tables = ["wave1", "wave10", "wave99"]
                assert tables == expected_tables
            finally:
                conn.close()
        finally:
            os.unlink(db_path)

    @staticmethod
    def dataframe_index_not_stored_in_database():
        """A custom frame index does not become a column in the table."""
        df_with_custom_index = pd.DataFrame({"col": [1, 2]})
        df_with_custom_index.index = ["row1", "row2"]
        test_data = {1: df_with_custom_index}
        with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
            db_path = tmp_file.name
        try:
            populate_database(test_data, db_path)
            conn = sqlite3.connect(db_path)
            try:
                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(wave1)")
                columns = [column[1] for column in cursor.fetchall()]
                assert "col" in columns
                assert "index" not in columns
                assert len(columns) == 1
            finally:
                conn.close()
        finally:
            os.unlink(db_path)