import pandas as pd
import seaborn as sns
from nnfwtbn import Variable, Process, Cut, confusion_matrix, HistogramFactory
from nnfwtbn import toydata
# Load the toy dataset shipped with nnfwtbn.
df = toydata.get()

# Truth-level processes. `range=(lo, hi)` selects events by a process-id
# value; presumably the toy data encodes signal as 1 and Z->tautau as 0 in
# its default process-id column -- TODO confirm against toydata.
p_sig = Process(r"Signal", range=(1, 1))
p_ztt = Process(r"$Z\rightarrow\tau\tau$", range=(0, 0))

# Region boundaries in m_jj. The regions use half-open bins
# [-inf, 350), [350, 600), [600, inf) so that every event lands in exactly
# one region. Previously the high region used `> 600`, which silently dropped
# events with m_jj == 600 from every region.
M_JJ_LOW_EDGE = 350
M_JJ_HIGH_EDGE = 600

c_low = Cut(lambda d: d.m_jj < M_JJ_LOW_EDGE, label="Low $m^{jj}$")
c_mid = Cut(lambda d: (d.m_jj >= M_JJ_LOW_EDGE) & (d.m_jj < M_JJ_HIGH_EDGE),
            label="Mid $m^{jj}$")
c_high = Cut(lambda d: d.m_jj >= M_JJ_HIGH_EDGE, label="High $m^{jj}$")
# Draw the yield table of the truth processes (columns, labelled
# "Truth Signal") against the m_jj regions (rows, labelled "Region") twice:
# first with the default normalization, then with each row normalized via
# `normalize_rows=True`. Both plots use the per-event "weight" column and
# annotate the cells.
for _extra_kwargs in ({}, {"normalize_rows": True}):
    confusion_matrix(
        df,
        [p_sig, p_ztt],
        [c_low, c_mid, c_high],
        x_label="Truth Signal",
        y_label="Region",
        weight="weight",
        annot=True,
        info=False,
        **_extra_kwargs,
    )