diff --git a/tableprint/printer.py b/tableprint/printer.py index 1ff16c0..240f7be 100644 --- a/tableprint/printer.py +++ b/tableprint/printer.py @@ -21,15 +21,15 @@ from six import string_types from .style import LineStyle, STYLES from .utils import ansi_len, format_line -__all__ = ('table', 'header', 'row', 'hrule', 'top', 'bottom', 'banner', 'dataframe') +__all__ = ('table', 'header', 'row', 'hrule', 'top', 'bottom', 'banner', 'dataframe', 'TableContext') STYLE = 'round' WIDTH = 11 FMT = '5g' -class Table: - def __init__(self, headers, width=WIDTH, style=STYLE, add_hr=True): +class TableContext: + def __init__(self, headers, width=WIDTH, style=STYLE, add_hr=True, out=sys.stdout): """Context manager for table printing Parameters @@ -37,7 +37,7 @@ class Table: headers : array_like A list of N strings consisting of the header of each of the N columns - width : int, optional + width : int or array_like, optional The width of each column in the table (Default: 11) style : string or tuple, optional @@ -48,22 +48,27 @@ class Table: Usage ----- - >>> with Table("ABC") as t: + >>> with TableContext("ABC") as t: for k in range(10): t.row(np.random.randn(3)) """ - self.headers = header(headers, width=width, style=style, add_hr=add_hr) - self.bottom = bottom(len(headers), width=width, style=style) + self.out = out + self.config = {'width': width, 'style': style} + self.headers = header(headers, add_hr=add_hr, **self.config) + self.bottom = bottom(len(headers), **self.config) def __call__(self, data): - print(row(data), flush=True) + self.out.write(row(data, **self.config) + '\n') + self.out.flush() def __enter__(self): - print(self.headers, flush=True) + self.out.write(self.headers + '\n') + self.out.flush() return self def __exit__(self, *exc): - print(self.bottom, flush=True) + self.out.write(self.bottom + '\n') + self.out.flush() def table(data, headers=None, format_spec=FMT, width=WIDTH, style=STYLE, out=sys.stdout): diff --git a/tests/test_io.py b/tests/test_io.py index ac7b5fd..2c2ceb9 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -1,12 +1,21 @@ # -*- coding: utf-8 -*- from __future__ import unicode_literals -from tableprint import table, banner, dataframe, hrule +from tableprint import table, banner, dataframe, hrule, TableContext from io import StringIO import numpy as np -def test_table(): +def test_context(): + """Tests the table context manager""" + output = StringIO() + with TableContext('ABC', style='round', width=5, out=output) as t: + t([1, 2, 3]) + t([4, 5, 6]) + assert output.getvalue() == '╭─────┬─────┬─────╮\n│ A │ B │ C │\n├─────┼─────┼─────┤\n│ 1│ 2│ 3│\n│ 4│ 5│ 6│\n╰─────┴─────┴─────╯\n' + +def test_table(): + """Tests the table function""" output = StringIO() table([[1, 2, 3], [4, 5, 6]], 'ABC', style='round', width=5, out=output) assert output.getvalue() == '╭─────┬─────┬─────╮\n│ A │ B │ C │\n├─────┼─────┼─────┤\n│ 1│ 2│ 3│\n│ 4│ 5│ 6│\n╰─────┴─────┴─────╯\n' @@ -17,7 +26,7 @@ def test_table(): def test_frame(): - + """Tests the dataframe function""" # mock of a pandas DataFrame class DataFrame: def __init__(self, data, headers): @@ -35,7 +44,7 @@ def test_frame(): def test_banner(): - + """Tests the banner function""" output = StringIO() banner('hello world', style='clean', width=11, out=output) assert output.getvalue() == ' ─────────── \n hello world \n ─────────── \n' @@ -46,7 +55,7 @@ def test_banner(): def test_hrule(): - + """Tests the hrule function""" output = hrule(1, width=11) assert len(output) == 11 assert '───────────'