"""Tests for ``populate_database``, which writes per-wave DataFrames into ``wave<N>`` SQLite tables."""
import contextlib
import os
import sqlite3
import tempfile
from unittest.mock import MagicMock, patch

import pandas as pd
import pytest

from src.utils.database_populator import populate_database


@contextlib.contextmanager
def _temp_db_path():
    """Yield the path of a fresh, empty temporary ``.sqlite`` file.

    The file handle is closed before yielding so SQLite can open the path on
    any platform, and the file is always removed on exit -- even when the test
    body raises (e.g. a failed assertion).
    """
    with tempfile.NamedTemporaryFile(suffix=".sqlite", delete=False) as tmp_file:
        db_path = tmp_file.name
    try:
        yield db_path
    finally:
        os.unlink(db_path)


def _connect(db_path):
    """Open ``db_path`` and return a context manager that closes the connection.

    A bare ``sqlite3.Connection`` in a ``with`` statement only commits or rolls
    back; it does not close. ``contextlib.closing`` guarantees the close, so the
    file can be unlinked afterwards even when an assertion fails.
    """
    return contextlib.closing(sqlite3.connect(db_path))


class TestPopulateDatabase:
    """Tests for ``populate_database``.

    Every test method carries the ``test_`` prefix so that pytest's default
    discovery (``python_functions = test*``) collects it.
    """

    @staticmethod
    def test_single_wave_data_creates_correct_table():
        """A single wave is written to a ``wave1`` table with its rows intact."""
        test_data = {1: pd.DataFrame({"col1": [1, 2], "col2": ["a", "b"]})}

        with _temp_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = cursor.fetchall()
                assert ("wave1",) in tables

                # Without ORDER BY, SQLite leaves row order undefined.
                cursor.execute("SELECT * FROM wave1 ORDER BY rowid")
                rows = cursor.fetchall()
                assert len(rows) == 2
                assert rows[0] == (1, "a")
                assert rows[1] == (2, "b")

    @staticmethod
    def test_multiple_waves_create_separate_tables():
        """Each wave key produces its own table and no extra tables are created."""
        test_data = {
            1: pd.DataFrame({"wave1_col": [1, 2]}),
            2: pd.DataFrame({"wave2_col": [3, 4]}),
            3: pd.DataFrame({"wave3_col": [5, 6]}),
        }

        with _temp_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = [table[0] for table in cursor.fetchall()]

        assert "wave1" in tables
        assert "wave2" in tables
        assert "wave3" in tables
        assert len(tables) == 3

    @staticmethod
    def test_empty_dataframe_creates_table_with_no_rows():
        """An empty DataFrame still yields a table with its column but no rows."""
        test_data = {1: pd.DataFrame({"empty_col": []})}

        with _temp_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT COUNT(*) FROM wave1")
                row_count = cursor.fetchone()[0]
                assert row_count == 0

                cursor.execute("PRAGMA table_info(wave1)")
                columns = cursor.fetchall()
                assert len(columns) == 1
                assert columns[0][1] == "empty_col"

    @staticmethod
    def test_empty_dictionary_creates_no_tables():
        """Passing no waves leaves the database without any tables."""
        test_data = {}

        with _temp_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT name FROM sqlite_master WHERE type='table'")
                tables = cursor.fetchall()

        assert len(tables) == 0

    @staticmethod
    def test_existing_database_tables_are_replaced():
        """A pre-existing ``wave1`` table is replaced, not appended to."""
        test_data = {1: pd.DataFrame({"col": [1, 2]})}

        with _temp_db_path() as db_path:
            # Seed a conflicting table with a different schema and a stale row.
            with _connect(db_path) as conn:
                conn.execute("CREATE TABLE wave1 (old_col INTEGER)")
                conn.execute("INSERT INTO wave1 VALUES (999)")
                conn.commit()

            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("SELECT * FROM wave1 ORDER BY rowid")
                rows = cursor.fetchall()
                assert len(rows) == 2
                assert rows[0] == (1,)
                assert rows[1] == (2,)

                cursor.execute("PRAGMA table_info(wave1)")
                columns = cursor.fetchall()
                assert len(columns) == 1
                assert columns[0][1] == "col"

    @staticmethod
    def test_database_uses_default_path_when_not_specified():
        """Omitting the path connects to ``results/study_results.sqlite``."""
        test_data = {1: pd.DataFrame({"col": [1]})}
        default_path = "results/study_results.sqlite"

        # NOTE(review): patching "sqlite3.connect" only intercepts the call if
        # the module under test calls it through the ``sqlite3`` module
        # attribute; a ``from sqlite3 import connect`` would bypass it.
        with patch("sqlite3.connect") as mock_connect:
            mock_connection = MagicMock()
            # Make ``with sqlite3.connect(...) as conn`` hand back the same mock,
            # matching the setup used in the exception test.
            mock_connection.__enter__.return_value = mock_connection
            mock_connect.return_value = mock_connection

            populate_database(test_data)

            mock_connect.assert_called_once_with(default_path)
            mock_connection.close.assert_called_once()

    @patch("sqlite3.connect")
    def test_connection_closed_even_when_exception_occurs(self, mock_connect):
        """The connection is closed even when writing a table raises."""
        mock_connection = MagicMock()
        mock_connect.return_value = mock_connection
        mock_connection.__enter__ = MagicMock(return_value=mock_connection)
        mock_connection.__exit__ = MagicMock(return_value=False)

        test_dataframe = pd.DataFrame({"col": [1, 2]})
        # Instance-level override; only effective if populate_database calls
        # ``to_sql`` on this exact object (not a copy).
        test_dataframe.to_sql = MagicMock(side_effect=Exception("SQL Error"))

        test_data = {1: test_dataframe}

        with pytest.raises(Exception, match="SQL Error"):
            populate_database(test_data, "test.db")

        mock_connection.close.assert_called_once()

    @staticmethod
    def test_dataframe_with_various_data_types_preserved():
        """Int, float, str and bool columns round-trip with their values."""
        test_data = {
            1: pd.DataFrame(
                {
                    "int_col": [1, 2],
                    "float_col": [1.5, 2.7],
                    "str_col": ["text1", "text2"],
                    "bool_col": [True, False],
                }
            )
        }

        with _temp_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                df_result = pd.read_sql_query(
                    "SELECT * FROM wave1 ORDER BY rowid", conn
                )

        assert len(df_result) == 2
        assert list(df_result.columns) == [
            "int_col",
            "float_col",
            "str_col",
            "bool_col",
        ]
        assert df_result["int_col"].tolist() == [1, 2]
        assert df_result["float_col"].tolist() == pytest.approx([1.5, 2.7])
        assert df_result["str_col"].tolist() == ["text1", "text2"]
        # SQLite has no boolean storage class; bools come back as 1/0, which
        # compare equal to True/False.
        assert df_result["bool_col"].tolist() == [True, False]

    @staticmethod
    def test_wave_numbers_create_correct_table_names():
        """Table names are ``wave`` followed by the integer wave key."""
        test_data = {
            10: pd.DataFrame({"col": [1]}),
            99: pd.DataFrame({"col": [2]}),
            1: pd.DataFrame({"col": [3]}),
        }

        with _temp_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute(
                    "SELECT name FROM sqlite_master WHERE type='table' ORDER BY name"
                )
                tables = [table[0] for table in cursor.fetchall()]

        # Lexicographic ORDER BY name, hence "wave10" before "wave99".
        expected_tables = ["wave1", "wave10", "wave99"]
        assert tables == expected_tables

    @staticmethod
    def test_dataframe_index_not_stored_in_database():
        """The DataFrame index is not written as an extra column."""
        df_with_custom_index = pd.DataFrame({"col": [1, 2]})
        df_with_custom_index.index = ["row1", "row2"]
        test_data = {1: df_with_custom_index}

        with _temp_db_path() as db_path:
            populate_database(test_data, db_path)

            with _connect(db_path) as conn:
                cursor = conn.cursor()
                cursor.execute("PRAGMA table_info(wave1)")
                columns = [column[1] for column in cursor.fetchall()]

        assert "col" in columns
        assert "index" not in columns
        assert len(columns) == 1
|