diff --git a/cfa_rt_postprocessing/main_functions.py b/cfa_rt_postprocessing/main_functions.py
index 1066119..790253d 100644
--- a/cfa_rt_postprocessing/main_functions.py
+++ b/cfa_rt_postprocessing/main_functions.py
@@ -405,6 +405,16 @@ def merge_and_render_anomaly(
     except Exception as e:
         console.log(f"Failed to upload the flu anomaly report: {e}")
 
+    # === Calculate the categories for the samples =====================================
+    console.log("Calculating the categories from the samples")
+    p_growing = calculate_categories(final_samples)
+
+    # Save the result to file, both as parquet and as CSV
+    p_growing_pq_file = internal_review / "p_growing.parquet"
+    p_growing_csv_file = internal_review / "p_growing.csv"
+    p_growing.write_parquet(p_growing_pq_file)
+    p_growing.write_csv(p_growing_csv_file)
+
     # === Clean up =====================================================================
     conn.close()
     console.log(f"Cleaning up {root} folder")
@@ -437,6 +447,55 @@ def render_report(
     )
 
 
+def calculate_categories(samples_file: Path) -> pl.DataFrame:
+    """
+    Takes the path to the samples parquet file and calculates p_growing and the
+    five-level growth category for each geo_value, disease, and reference_date.
+
+    Returns a DataFrame with the columns:
+    - geo_value
+    - disease
+    - reference_date
+    - p_growing
+    - five_cat_p_growing
+
+    The samples file is fairly large, so use duckdb to keep memory use and runtime
+    manageable.
+    """
+    conn = duckdb.connect()
+    conn.sql(
+        f"""
+        -- Create a 'view' of the samples rather than a table, so the file is not
+        -- read entirely into RAM. At the moment a single samples file is about
+        -- 1.7 GB, and loading it into memory crashes a 32 GB machine.
+        CREATE OR REPLACE VIEW samples AS
+        SELECT reference_date, geo_value, disease, "value" AS Rt
+        FROM '{str(samples_file.absolute())}'
+        WHERE "_variable" = 'Rt';
+
+        -- Calculate p_growing for each geo_value, disease, and reference_date
+        CREATE OR REPLACE TABLE p_growing AS SELECT
+            geo_value, disease, reference_date,
+            AVG(IF(Rt > 1, 1, 0)) AS p_growing,
+            CASE
+                WHEN (p_growing > 0.9) AND (p_growing <= 1.0) THEN 'Growing'
+                WHEN (p_growing > 0.75) AND (p_growing <= 0.9) THEN 'Likely Growing'
+                WHEN (p_growing > 0.25) AND (p_growing <= 0.75) THEN 'Not Changing'
+                WHEN (p_growing > 0.10) AND (p_growing <= 0.25) THEN 'Likely Declining'
+                WHEN (p_growing >= 0.0) AND (p_growing <= 0.10) THEN 'Declining'
+            END AS five_cat_p_growing
+        FROM samples
+        GROUP BY ALL
+        ORDER BY ALL;
+        """
+    )
+
+    p_growing = conn.sql("SELECT * FROM p_growing").pl()
+    conn.close()
+
+    return p_growing
+
+
 if __name__ == "__main__":
     # Some sample inputs for testing. Need to move something like this to an actual test
     args = {
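
Not part of the patch: a minimal sketch of how calculate_categories() might be sanity-checked
against a tiny synthetic samples file. The column names (reference_date, geo_value, disease,
_variable, value) and the expected "Not Changing" category follow from the SQL and thresholds
in the hunk above; the import path is taken from the file being patched, and the snippet
assumes polars and duckdb are installed.

    # Sketch of a quick check for calculate_categories() on synthetic data.
    from pathlib import Path
    import tempfile

    import polars as pl

    from cfa_rt_postprocessing.main_functions import calculate_categories

    with tempfile.TemporaryDirectory() as tmp:
        samples_file = Path(tmp) / "samples.parquet"

        # Four Rt draws for one geo_value/disease/reference_date, three of them
        # above 1, so p_growing should come out to 0.75.
        pl.DataFrame(
            {
                "reference_date": ["2025-01-01"] * 4,
                "geo_value": ["CA"] * 4,
                "disease": ["COVID-19"] * 4,
                "_variable": ["Rt"] * 4,
                "value": [0.9, 1.1, 1.2, 1.3],
            }
        ).write_parquet(samples_file)

        result = calculate_categories(samples_file)
        print(result)
        # Expect p_growing == 0.75, which falls in the (0.25, 0.75] bin,
        # so five_cat_p_growing should be "Not Changing".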