Source code for chart_me.charting_assembly_strategy.bivariate

"""Default implementation for bivariate variable charting

This module contains all the logic to go from Metadata to actual charts
See supporting documentation that discusses the rules engine.

    Typical usage example:

    charts = assemble_bivariate_charts(df, [col1, col2], infered_data_types)
"""
# Standard library imports
from typing import List, Union

# Third party imports
import altair as alt
import pandas as pd

# chart_me imports
from chart_me.datatype_infer_strategy import ChartMeDataTypeMetaType, InferedDataTypes
from chart_me.pandas_util import pd_group_me, pd_truncate_date


[docs]def assemble_bivariate_charts( df: pd.DataFrame, cols: List[str], infered_data_types: InferedDataTypes, **kwargs ) -> List[Union[alt.Chart, alt.HConcatChart]]: """Delegated Function to Manage Bivariate Use Cases Args: df: dataframe cols: a list of two columns infered_data_types: An instance of InferedDataTypes object Returns: List of altair charts or compounds charts Raises: ValueError if called with len of cols != 2 """ if len(cols) == 1: raise ValueError("This module only supports two valid columns") col_name1 = cols[0] col_name2 = cols[1] merged_meta_types = { ChartMeDataTypeMetaType.CATEGORICAL_HIGH_CARDINALITY: ChartMeDataTypeMetaType.KEY, # noqa: E501 ChartMeDataTypeMetaType.BOOLEAN: ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY, # noqa: E501 } col_MT1 = merged_meta_types.get( infered_data_types.chart_me_data_types_meta[col_name1], infered_data_types.chart_me_data_types_meta[col_name1], ) col_MT2 = merged_meta_types.get( infered_data_types.chart_me_data_types_meta[col_name2], infered_data_types.chart_me_data_types_meta[col_name2], ) preagg_fl = ( infered_data_types.preaggregated ) # doesn't impact behavior for univariate/bivariate return_charts = [] if {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.QUANTITATIVE, ChartMeDataTypeMetaType.QUANTITATIVE, }: return_charts.append(build_scatter_plot(df, col_name1, col_name2)) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.QUANTITATIVE, ChartMeDataTypeMetaType.KEY, }: col_name_x, col_name_y = ( [col_name2, col_name1] if col_MT1 == ChartMeDataTypeMetaType.KEY else [col_name1, col_name2] ) return_charts.append(build_hbar_value(df.head(), col_name_x, col_name_y)) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.QUANTITATIVE, ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY, }: col_name_n, col_name_q = ( [col_name1, col_name2] if col_MT1 == ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY else [col_name2, col_name1] ) if not preagg_fl: return_charts.append(build_facet_histogram(df, col_name_n, col_name_q)) agg_dict = {f"{col_name_q}": ["count", "min", "max", "mean", "median"]} df = pd_group_me(df, col_name_n, agg_dict, make_long_form=True) # return_charts.insert(0, df) # TODO how to store manipulated dataframes to pass back to user return_charts.append( build_facet_hbars( df, col_name_facet="measures", col_name_y=col_name_n, col_name_x=col_name_q, ) ) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.QUANTITATIVE, ChartMeDataTypeMetaType.TEMPORAL, }: col_name_t, col_name_q = ( [col_name1, col_name2] if col_MT1 == ChartMeDataTypeMetaType.TEMPORAL else [col_name2, col_name1] ) col_name_t_m_y = f"{col_name_t}_m_y" df[col_name_t_m_y] = pd_truncate_date(df, col_name_t) agg_dict = {col_name_q: ["count", "min", "max", "mean", "median"]} df = pd_group_me( df, col_name_t_m_y, agg_dict, is_temporal=True, make_long_form=True ) return_charts.append( build_hconcat_temp_charts(df, col_name_t_m_y, col_name_q, "measures") ) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.KEY, ChartMeDataTypeMetaType.KEY, }: df["quantity"] = 1 df["label"] = df[col_name1].astype(str) + "-" + df[col_name2].astype(str) return_charts.append( build_hbar_value(df.head(n=20), col_name_x="quantity", col_name_y="label") ) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.KEY, ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY, }: col_name_n, col_name_k = ( [col_name1, col_name2] if col_MT1 == ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY else [col_name2, col_name1] ) agg_dict = {col_name_k: ["count", "nunique"]} df = pd_group_me(df, col_name_n, agg_dict, make_long_form=True) return_charts.append( build_facet_hbars( df, col_name_facet="measures", col_name_y=col_name_n, col_name_x=col_name_k, ) ) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.KEY, ChartMeDataTypeMetaType.TEMPORAL, }: col_name_t, col_name_k = ( [col_name1, col_name2] if col_MT1 == ChartMeDataTypeMetaType.TEMPORAL else [col_name2, col_name1] ) col_name_t_m_y = f"{col_name_t}_m_y" df[col_name_t_m_y] = pd_truncate_date(df, col_name_t) agg_dict = {col_name_k: ["count", "nunique"]} df = pd_group_me( df, col_name_t_m_y, agg_dict, is_temporal=True, make_long_form=True ) return_charts.append( build_hconcat_temp_charts(df, col_name_t_m_y, col_name_k, "measures") ) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY, ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY, }: df["_counts_"] = 1 df = pd_group_me( df, cols=[col_name1, col_name2], agg_dict={"_counts_": ["sum"]}, make_long_form=True, ) return_charts.append(build_heatmap(df, col_name1, col_name2, "_counts_")) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.CATEGORICAL_LOW_CARDINALITY, ChartMeDataTypeMetaType.TEMPORAL, }: col_name_t, col_name_n = ( [col_name1, col_name2] if col_MT1 == ChartMeDataTypeMetaType.TEMPORAL else [col_name2, col_name1] ) col_name_t_m_y = f"{col_name_t}_m_y" df[col_name_t_m_y] = pd_truncate_date(df, col_name_t) df["_counts_"] = 1 df = pd_group_me( df, cols=[col_name_t_m_y, col_name_n], agg_dict={"_counts_": ["sum"]}, make_long_form=True, ) return_charts.append( build_hconcat_temp_lc_charts(df, col_name_t_m_y, col_name_n, "_counts_") ) elif {col_MT1, col_MT2} == { ChartMeDataTypeMetaType.TEMPORAL, ChartMeDataTypeMetaType.TEMPORAL, }: col_name_t_m_y1, col_name_t_m_y2 = [f"{col_name1}_m_y", f"{col_name2}_m_y"] df[col_name_t_m_y1] = pd_truncate_date(df, col_name1) df[col_name_t_m_y2] = pd_truncate_date(df, col_name2) df["__day_diff__"] = (df[col_name1] - df[col_name2]).dt.days # chart_me imports from chart_me.charting_assembly_strategy.univariate import build_histogram return_charts.append(build_histogram(df, "__day_diff__")) df["_counts_"] = 1 df = pd_group_me( df, cols=[col_name_t_m_y1, col_name_t_m_y2], agg_dict={"_counts_": ["sum"]}, make_long_form=True, ) return_charts.append( build_heatmap(df, col_name_t_m_y1, col_name_t_m_y2, "_counts_") ) else: raise NotImplementedError( f"unknown handling of metatype-{str(col_MT1)}-{str(col_MT2)}" ) return return_charts
[docs]def build_scatter_plot(df: pd.DataFrame, col_name1: str, col_name2: str) -> alt.Chart: """An implementation of scatter plot""" chart = alt.Chart(df).mark_point().encode(x=f"{col_name1}:Q", y=f"{col_name2}:Q") return chart
[docs]def build_hbar_value(df: pd.DataFrame, col_name_x: str, col_name_y: str) -> alt.Chart: """An implementation of horizontal bar chart""" chart = alt.Chart(df).mark_bar().encode(x=f"{col_name_x}:Q", y=f"{col_name_y}:O") return chart
[docs]def build_facet_histogram( df: pd.DataFrame, col_name_facet: str, col_name_hist_q: str ) -> alt.Chart: """An implementation of histogram faceted by nominal variable""" chart = ( alt.Chart(df) .mark_bar() .encode( alt.X(f"{col_name_hist_q}:Q", bin=True), y="count()", facet=alt.Facet(f"{col_name_facet}:O", columns=4), ) ) return chart
[docs]def build_facet_hbars( df: pd.DataFrame, col_name_facet: str, col_name_y: str, col_name_x: str ) -> alt.Chart: """An implementation of horizontal bar graph faceted by nominal variable""" chart = ( alt.Chart(df) .mark_bar() .encode( x=f"{col_name_x}:Q", y=f"{col_name_y}:O", facet=alt.Facet(f"{col_name_facet}:O", columns=2), ) .resolve_scale(x="independent") ) return chart
[docs]def build_hconcat_temp_charts( df: pd.DataFrame, col_name_y_m, col_name_q, col_name_measure: str = "measures" ) -> alt.HConcatChart: """An implementation of horizontal bar graph faceted by nominal variable WARNING: Function assumes processing through pd_group_me to separate count from other aggregations """ df_cnt = df[df[col_name_measure] == "count"] df_msr = df[df[col_name_measure] != "count"] chart1 = ( alt.Chart(df_cnt).mark_bar().encode(x=f"{col_name_y_m}:O", y=f"{col_name_q}:Q") ) chart2 = ( alt.Chart(df_msr) .mark_line() .encode(x=f"{col_name_y_m}:O", y=f"{col_name_q}:Q", color="measures") ) return alt.hconcat(chart1, chart2)
[docs]def build_heatmap( df: pd.DataFrame, col_name_x: str, col_name_y: str, col_name_q: str ) -> alt.Chart: """An implementation of heatmap""" chart = ( alt.Chart(df) .mark_rect() .encode(x=f"{col_name_x}:O", y=f"{col_name_y}:O", color=f"{col_name_q}:Q") ) return chart
[docs]def build_hconcat_temp_lc_charts( df: pd.DataFrame, col_name_t_y_m: str, col_name_n: str, col_name_q: str ) -> alt.HConcatChart: """An implementation that returns two charts to trend nominal and relative values chart 1-> bar trend chart chart 2-> stacked bar chart """ chart1 = ( alt.Chart(df) .mark_bar() .encode(x=f"{col_name_t_y_m}:O", y=f"{col_name_q}:Q", color=col_name_n) ) chart2 = ( alt.Chart(df) .mark_bar() .encode( x=f"{col_name_t_y_m}:O", y=alt.Y(f"{col_name_q}:Q", stack="normalize"), color=col_name_n, ) ) return alt.hconcat(chart1, chart2)