# preprocessing/src/composite_processor.py
# 2025-12-15 13:47:28 +01:00
#
# 122 lines
# 5.1 KiB
# Python

import pandas as pd
def process_composites(
dataframe: pd.DataFrame,
composite_scale_specifications: dict,
wave_alphas: dict[str, float | None],
scale_item_counts: dict[str, int] | None = None,
) -> tuple[pd.DataFrame, dict]:
"""Compute composite scale columns based on provided specifications.
Iterates over composite scale definitions and calculates new columns using the specified aggregation method
('mean', 'sum', 'weighted_mean', 'coalesce'). Supports subgroup filtering if specified.
Args:
dataframe (pd.DataFrame): DataFrame containing all computed scales.
composite_scale_specifications (dict): Dictionary with composite scale definitions from the wave config.
wave_alphas (dict[str, float] | None): Existing dictionary of Cronbach's alpha values for scales in the wave.
Returns:
tuple[pd.DataFrame, dict[str, float]]: DataFrame containing the new composite scale columns
and updated alpha values dictionary.
Raises:
ValueError: If required columns are missing, or an unknown method is specified.
NotImplementedError: If the 'categorical' method is requested.
"""
composites: dict = {}
updated_alphas: dict = {}
for (
composite_scale_name,
composite_scale_specification,
) in composite_scale_specifications.items():
scale_columns: list = composite_scale_specification.get("scales", [])
method: str = composite_scale_specification.get("method", "mean")
subgroup: str = composite_scale_specification.get("subgroup", "all")
if len(scale_columns) == 0:
continue
missing: list[str] = [
col for col in scale_columns if col not in dataframe.columns
]
if missing:
raise ValueError(
f"Missing columns for composite {composite_scale_name}: {missing}"
)
mask: pd.Series = pd.Series(True, dataframe.index)
if subgroup and subgroup != "all" and subgroup in dataframe.columns:
mask = dataframe[subgroup].astype(bool)
dataframe_subset: pd.DataFrame = dataframe.loc[mask, scale_columns]
if method == "mean":
composite_scores: pd.Series = dataframe_subset.mean(axis=1)
elif method == "sum":
composite_scores = dataframe_subset.sum(axis=1)
elif method == "weighted_mean":
weights_spec = composite_scale_specification.get("weights")
if weights_spec is not None:
weights = pd.Series(weights_spec, dtype="float64")
weights = weights.reindex(scale_columns)
if weights.isna().any():
missing_weights = weights[weights.isna()].index.tolist()
raise ValueError(
f"Composite {composite_scale_name}: Missing weights for scales {missing_weights}"
)
elif scale_item_counts is not None:
weights = pd.Series(
[scale_item_counts.get(col, 1) for col in scale_columns],
index=scale_columns,
dtype="float64",
)
else:
raise ValueError(
f"Composite {composite_scale_name}: No weights specified and no scale_item_counts provided."
)
weighted_values = dataframe_subset.mul(weights, axis=1)
numerator = weighted_values.sum(axis=1, skipna=True)
denom_weights = dataframe_subset.notna().mul(weights, axis=1)
denominator = denom_weights.sum(axis=1)
composite_scores = numerator / denominator
composite_scores = composite_scores.where(denominator > 0, pd.NA)
elif method == "categorical":
raise NotImplementedError(
"'categorical' method is not supported as a composite aggregation (use 'coalesce')."
)
elif method == "coalesce":
def coalesce_row(row: pd.Series):
present: pd.Series = row.notna()
if present.sum() > 1:
raise ValueError(
f"Composite '{composite_scale_name}': More than one non-missing value in row (participant_id={dataframe.loc[row.name, 'participant_id']}): {row[present].to_dict()}"
)
return row[present].iloc[0] if present.any() else pd.NA
composite_scores = dataframe_subset.apply(coalesce_row, axis=1)
constituent_alphas = [
wave_alphas.get(col)
for col in scale_columns
if wave_alphas and col in wave_alphas and wave_alphas[col] is not None
]
if constituent_alphas:
updated_alphas[composite_scale_name] = constituent_alphas
else:
raise ValueError(f"Unknown composite method: {method}")
result_column: pd.Series = pd.Series(pd.NA, index=dataframe.index)
result_column[mask] = composite_scores
composites[composite_scale_name] = result_column
return pd.DataFrame(composites), updated_alphas