# Source code for nnfwtbn.tests.test_cut


from nnfwtbn.cut import Cut
import unittest
import pandas as pd


class CutTestCase(unittest.TestCase):
    """
    Test the implementation of the cut class.

    Every test works on the small year/sale dataframe built in setUp().
    The cut used throughout, ``sale > 4``, selects the rows of the years
    2012, 2013, 2014 and 2017.
    """

    def setUp(self):
        """
        Create a default dataframe for testing.

        The dataframe has eight rows and the two columns "year" and "sale".
        """
        self.df = pd.DataFrame([[2010, 3.9],
                                [2011, 2.8],
                                [2012, 4.7],
                                [2013, 5.6],
                                [2014, 7.5],
                                [2015, 3.4],
                                [2016, 2.3],
                                [2017, 4.2]],
                               columns=["year", "sale"])

    def test_default_cut(self):
        """
        Make sure that the default cut accepts every event in the dataframe.
        """
        default = Cut()
        selected = default.idx_array(self.df)

        self.assertTrue((selected).all())

    def test_init_with_lambda(self):
        """
        Check that a cut created from a lambda expression uses the lambda to
        filter the dataframe.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        selected = high_sale.idx_array(self.df)

        self.assertEqual(list(selected),
                         [False, False, True, True, True, False, False, True])
        self.assertEqual(list(self.df[selected].year),
                         [2012, 2013, 2014, 2017])

    def test_and(self):
        """
        Check that two cuts can be joined with a logical AND.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        old = Cut(lambda df: df.year < 2015)
        combined = high_sale & old
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [False, False, True, True, True, False, False, False])
        self.assertEqual(list(self.df[selected].year), [2012, 2013, 2014])

    def test_and_lambda(self):
        """
        Check that a cut and a lambda can be joined with a logical AND.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        combined = high_sale & (lambda df: df.year < 2015)
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [False, False, True, True, True, False, False, False])
        self.assertEqual(list(self.df[selected].year), [2012, 2013, 2014])

    def test_rand_lambda(self):
        """
        Check that a lambda on the left-hand side and a cut can be joined
        with a logical AND.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        combined = (lambda df: df.year < 2015) & high_sale
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [False, False, True, True, True, False, False, False])
        self.assertEqual(list(self.df[selected].year), [2012, 2013, 2014])

    def test_and_bool(self):
        """
        Check that a cut and a constant truth value (the integers 1 and 0)
        can be joined with a logical AND.
        """
        high_sale = Cut(lambda df: df.sale > 4)

        # AND with a true constant leaves the selection unchanged.
        combined = high_sale & 1
        selected = combined.idx_array(self.df)
        self.assertEqual(list(selected),
                         [False, False, True, True, True, False, False, True])
        self.assertEqual(list(self.df[selected].year),
                         [2012, 2013, 2014, 2017])

        # AND with a false constant rejects every row.
        combined = high_sale & 0
        selected = combined.idx_array(self.df)
        self.assertFalse((selected).any())

    def test_or(self):
        """
        Check that two cuts can be joined with a logical OR.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        old = Cut(lambda df: df.year < 2015)
        combined = high_sale | old
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [True, True, True, True, True, False, False, True])
        self.assertEqual(list(self.df[selected].year),
                         [2010, 2011, 2012, 2013, 2014, 2017])

    def test_or_lambda(self):
        """
        Check that a cut and a lambda can be joined with a logical OR.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        combined = high_sale | (lambda df: df.year < 2015)
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [True, True, True, True, True, False, False, True])
        self.assertEqual(list(self.df[selected].year),
                         [2010, 2011, 2012, 2013, 2014, 2017])

    def test_ror_lambda(self):
        """
        Check that a lambda on the left-hand side and a cut can be joined
        with a logical OR.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        combined = (lambda df: df.year < 2015) | high_sale
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [True, True, True, True, True, False, False, True])
        self.assertEqual(list(self.df[selected].year),
                         [2010, 2011, 2012, 2013, 2014, 2017])

    def test_or_bool(self):
        """
        Check that a cut and a constant truth value (the integers 1 and 0)
        can be joined with a logical OR.
        """
        high_sale = Cut(lambda df: df.sale > 4)

        # OR with a true constant accepts every row.
        combined = high_sale | 1
        selected = combined.idx_array(self.df)
        self.assertTrue(selected.all())

        # OR with a false constant leaves the selection unchanged.
        combined = high_sale | 0
        selected = combined.idx_array(self.df)
        self.assertEqual(list(selected),
                         [False, False, True, True, True, False, False, True])
        self.assertEqual(list(self.df[selected].year),
                         [2012, 2013, 2014, 2017])

    def test_xor(self):
        """
        Check that two cuts can be joined with a logical XOR.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        old = Cut(lambda df: df.year < 2015)
        combined = high_sale ^ old
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [True, True, False, False, False, False, False, True])
        self.assertEqual(list(self.df[selected].year), [2010, 2011, 2017])

    def test_xor_lambda(self):
        """
        Check that a cut and a lambda can be joined with a logical XOR.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        combined = high_sale ^ (lambda df: df.year < 2015)
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [True, True, False, False, False, False, False, True])
        self.assertEqual(list(self.df[selected].year), [2010, 2011, 2017])

    def test_rxor_lambda(self):
        """
        Check that a lambda on the left-hand side and a cut can be joined
        with a logical XOR.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        combined = (lambda df: df.year < 2015) ^ high_sale
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [True, True, False, False, False, False, False, True])
        self.assertEqual(list(self.df[selected].year), [2010, 2011, 2017])

    def test_xor_bool(self):
        """
        Check that a cut and a constant truth value (the integers 1 and 0)
        can be joined with a logical XOR.
        """
        high_sale = Cut(lambda df: df.sale > 4)

        # XOR with a true constant inverts the selection.
        combined = high_sale ^ 1
        selected = combined.idx_array(self.df)
        self.assertEqual(list(selected),
                         [True, True, False, False, False, True, True, False])
        self.assertEqual(list(self.df[selected].year),
                         [2010, 2011, 2015, 2016])

        # XOR with a false constant leaves the selection unchanged.
        combined = high_sale ^ 0
        selected = combined.idx_array(self.df)
        self.assertEqual(list(selected),
                         [False, False, True, True, True, False, False, True])
        self.assertEqual(list(self.df[selected].year),
                         [2012, 2013, 2014, 2017])

    def test_not(self):
        """
        Check that a cut can be logically negated with the ~ operator.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        combined = ~high_sale
        selected = combined.idx_array(self.df)

        self.assertEqual(list(selected),
                         [True, True, False, False, False, True, True, False])
        self.assertEqual(list(self.df[selected].year),
                         [2010, 2011, 2015, 2016])

    def test_label(self):
        """
        Check that a label specified during construction is available via
        the 'label' attribute.
        """
        high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
        self.assertEqual(high_sale.label, "High sales volume")

    def test_init_cut(self):
        """
        Check that a cut can be passed to the constructor.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        high_sale2 = Cut(high_sale)

        # Calling the cut yields the 4 selected rows, while the index array
        # has one entry for each of the 8 input rows.
        self.assertEqual(len(high_sale2(self.df)), 4)
        self.assertEqual(len(high_sale2.idx_array(self.df)), 8)

    def test_init_cut_name_inherit(self):
        """
        Check that the label of a cut passed to the constructor is inherited.
        """
        high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
        high_sale2 = Cut(high_sale)

        self.assertEqual(high_sale2.label, "High sales volume")

    def test_init_cut_name_inherit_precedence(self):
        """
        Check that the label argument has precedence over the label of the
        given cut.
        """
        high_sale = Cut(lambda df: df.sale > 4, label="High sales volume")
        high_sale2 = Cut(high_sale, label="Other label")

        self.assertEqual(high_sale2.label, "Other label")

    def test_call(self):
        """
        Check that calling the cut returns a dataframe containing the
        selected events rather than returning an index array.
        """
        high_sale = Cut(lambda df: df.sale > 4)
        high_sale_years = high_sale(self.df)

        self.assertEqual(list(high_sale_years.sale), [4.7, 5.6, 7.5, 4.2])
        self.assertEqual(list(high_sale_years.year),
                         [2012, 2013, 2014, 2017])

    def test_call_empty_input(self):
        """
        Check that calling the object returns an empty dataframe if the
        input dataframe is empty.
        """
        df = pd.DataFrame({"year": [], "sale": []})
        high_sale = Cut(lambda df: df.sale > 4)

        self.assertEqual(list(high_sale(df).year), [])

    def test_call_no_match(self):
        """
        Check that calling the object returns an empty dataframe when no
        event matches the selection.
        """
        high_sale = Cut(lambda df: df.sale > 10)

        self.assertEqual(list(high_sale(self.df).year), [])