"""
This module exports two classes: 'RowTable' and 'ColTable'.  A 'RowTable' is a
table that is filled row by row, a 'ColTable' is filled col by
col. After all values have been entered into the table, you can get an
external representation of the table in three formats: LaTeX,
HTML and CSV. HTML and CSV are not fully supported, some operations make
only sense for LaTeX.


Example::

    tab = RowTable()
    tab.addColSeparator(0)
    tab.addRow([1, 2, 3])
    tab.addRowSeparator()
    tab.addRow([4, 5, 6])
    latexOutput = tab.getAsLatex()


The code above would produce a table like this:
<pre>
  1 | 2  3
  --+-------
  4 | 5  6
</pre>


For use with LaTeX you should put something like this into your document:
<pre>
\usepackage{dcolumn}
% the 1st arg is the decimal separator used in the .tex file,
% the 2nd arg is the decimal separator that should be used in the .dvi
%   file
% the 3rd specifies the number of digits to the left and to the
%   right of the decimal separator
% Remember: the .tex file may also contain grouping separators
\newcolumntype{d}[0]{D{.}{.}{3}}
</pre>

@author: Stefan Wehr
@url: http://www.stefanwehr.de
@license: LGPL, http://www.gnu.org/licenses/lgpl.txt
@exports: ColTable, RowTable
@version: 2005-03-15
"""


import locale
import numbers
from types import *

def formatNumber(f, ndec=3):
    """
    Returns the number 'f' formatted as a string, using the current locale.

    Values that are not real numbers (and booleans) are returned unchanged.
    Numbers greater than 9999 are written in scientific notation with five
    decimals ('%.5E'). All other integers are written with '%d'. All other
    numbers are written with 'ndec' decimals, after which trailing zeros
    are removed as long as at least one decimal remains.
    """
    if isinstance(f, bool) or not isinstance(f, numbers.Real):
        return f
    threshold = 9999
    isLarge = f > threshold
    isIntegral = isinstance(f, numbers.Integral)
    if isLarge:
        fmt = '%.5E'
    elif isIntegral:
        fmt = '%d'
    else:
        fmt = '%%.%df' % ndec
    # locale.format_string replaces locale.format, which is no longer
    # available in current Python versions.
    s = locale.format_string(fmt, f, grouping=True)
    if isLarge or isIntegral:
        return s
    # Strip at most ndec-1 trailing zeros so that one decimal always stays.
    for _ in range(ndec - 1):
        if not s.endswith('0'):
            break
        s = s[:-1]
    return s

class TableError(Exception):
    """
    Raised when a row or col does not fit the shape of the table.
    """

def addRowToCols(row, cols):
    """
    Adds a 'row' to the 'cols' of a table. A 'row' is a list of values. The
    parameter 'cols' is a list containing the cols of the table. The i-th
    value of 'row' is appended to the i-th col.

    Raises a 'TableError' if 'row' and 'cols' differ in length; in that case
    no col is modified.
    """
    if len(row) != len(cols):
        raise TableError('Row has length %d, but there are %d columns' %
                         (len(row), len(cols)))
    for value, col in zip(row, cols):
        col.append(value)

class AbstractTable:
    """
    Common base class of the tables defined in this module.

    The cell values are stored col by col in 'self.cols'. The export
    methods 'getAsCsv', 'getAsHtml' and 'getAsLatex' render the table from
    there. A cell may occupy several cols (see 'setColspan'); the cols that
    are covered by such a cell are skipped when a row is rendered.
    """

    def __init__(self, caption=None):
        """
        Creates an empty table with an optional 'caption'.
        """
        # list of cols, every col is a list of cell values
        self.cols = []
        # header cells, rendered in front of the body if not empty
        self.headers = []
        # indices of the rows / cols that are followed by a separator
        self.rowSeparators = []
        self.colSeparators = []
        # latex col formats by col index, used in the column specification
        self.colFormats = {}
        # latex formats by col index, used for cells with textual content
        self.colTextFormats = {}
        self.caption = caption
        # number of decimals passed to 'formatNumber'
        self.ndec = 3
        # if true, every col is followed by a separator
        self.withColSeparators = 0
        # col format used for cols without an entry in 'colFormats'
        self.defaultColFormat = 'd'
        # records how many cols the cell at (row,col) occupies
        self.colspan = {}
        self.longtable = 0
        # useless for a column based table. only needed for RowTable
        self.lastHeaderRow = -1
        self.label = ''
        self.fontsize = None

    def _rowCount(self):
        """
        Returns the number of rows, i.e. the length of the first col, or 0
        if the table has no cols yet.
        """
        if not self.cols:
            return 0
        return len(self.cols[0])

    def _isNumber(self, value):
        """
        Tells whether 'value' is rendered as a number. Booleans are
        treated as text.
        """
        return isinstance(value, numbers.Real) and not isinstance(value, bool)

    def isLongtable(self, bool):
        """
        Specifies if the latex package 'longtable' should be used for
        this table.
        """
        self.longtable = bool

    def getColspan(self, rowIndex, colIndex):
        """
        Returns the number of cols the cell at ('rowIndex', 'colIndex')
        occupies. The default is 1.
        """
        return self.colspan.get((rowIndex, colIndex), 1)

    def setColspan(self, rowIndex, colIndex, value):
        """
        Records that the cell at ('rowIndex', 'colIndex') occupies 'value'
        cols. 'value' must be at least 1.
        """
        assert value >= 1
        self.colspan[(rowIndex, colIndex)] = value

    def getAsCsv(self, colsep=';'):
        """
        Returns a string with the table in CSV (comma separated values)
        format. Some information might be lost when exporting the table
        to the CSV format.

        'colsep' separates the cells of a line, in the header line as well
        as in the body. The header line is omitted if the table has no
        headers. Cell values are not quoted or escaped.
        """
        lines = []
        if self.headers:
            lines.append(colsep.join([str(h) for h in self.headers]) + '\n')
        lines.append(self.formatBodyForCsv(colsep))
        return ''.join(lines)

    def formatBodyForCsv(self, colsep):
        """
        Returns the rows of the table as CSV lines, one line per row. A
        cell that occupies several cols is followed by empty cells so that
        every line has the same number of separators. No line ends with a
        separator.
        """
        n = len(self.cols)
        lines = []
        for rowIndex in range(self._rowCount()):
            cells = []
            colIndex = 0
            while colIndex < n:
                # max() guards against a non-positive colspan entry, which
                # would otherwise keep the loop from advancing
                colspan = max(self.getColspan(rowIndex, colIndex), 1)
                cells.append(str(self.cols[colIndex][rowIndex]))
                cells.extend([''] * (colspan - 1))
                colIndex += colspan
            lines.append(colsep.join(cells) + '\n')
        return ''.join(lines)

    def getAsHtml(self):
        """
        Returns a string with the table in HTML format. Some
        information might be lost when exporting the table to HTML
        format.

        The caption, if any, is used as the title of the document. Cell
        values are inserted as they are, without HTML escaping.
        """
        parts = ['<html><head>\n']
        if self.caption:
            parts.append('<title>%s</title>\n' % self.caption)
        parts.append('</head>\n<body>\n<table border="1">\n')
        parts.append(self.formatHeaderForHtml())
        parts.append(self.formatBodyForHtml())
        parts.append('</table>\n</body>\n</html>')
        return ''.join(parts)

    def formatHeaderForHtml(self):
        """
        Returns the headers as one HTML table row, or the empty string if
        the table has no headers.
        """
        if not self.headers:
            return ''
        cells = ['<td>%s</td>' % h for h in self.headers]
        return '<tr>' + ''.join(cells) + '</tr>\n'

    def formatBodyForHtml(self):
        """
        Returns the rows of the table as HTML table rows.
        """
        n = len(self.cols)
        parts = []
        for rowIndex in range(self._rowCount()):
            parts.append('<tr>\n')
            colIndex = 0
            while colIndex < n:
                content, colspan = \
                    self.formatCellContentForHtml(rowIndex, colIndex)
                parts.append('<td colspan="%d">%s</td>' % (colspan, content))
                # skip the cols covered by this cell
                colIndex += max(colspan, 1)
            parts.append('</tr>\n')
        return ''.join(parts)

    def formatCellContentForHtml(self, rowIndex, colIndex):
        """
        Returns the tuple (content, colspan) for the cell at ('rowIndex',
        'colIndex'). Numbers are formatted with 'formatNumber', empty
        content is replaced by a non-breaking space.
        """
        content = self.cols[colIndex][rowIndex]
        colspan = self.getColspan(rowIndex, colIndex)
        if self._isNumber(content):
            s = str(formatNumber(content, ndec=self.ndec))
        else:
            s = content
        if not s:
            s = '&nbsp;'
        return (s, colspan)

    def getAsLatex(self):
        """
        Returns a string with the table in LaTeX format.

        The table is a 'longtable' if requested by 'isLongtable', otherwise
        a 'tabular' that is wrapped in a 'table' environment if the table
        has a caption. The label is only written if there is a caption.
        """
        colSpec = ''.join(
            [self.colFormats.get(i, self.defaultColFormat) +
             self.getColSeparatorForLatex(i)
             for i in range(len(self.cols))])
        floating = not self.longtable and self.caption
        parts = []

        if floating:
            parts.append('\\begin{table}\n')
        if self.fontsize:
            parts.append('{\\' + self.fontsize + '\n')
        if self.longtable:
            parts.append('\\begin{longtable}{' + colSpec + '}\n')
        else:
            parts.append('\\begin{tabular}{' + colSpec + '}\n')

        header = self.formatHeaderForLatex()
        if header:
            parts.append(header + '\\hline\n')
        if self.longtable:
            parts.append('\\endhead')
        parts.append(self.formatBodyForLatex())

        if not self.longtable:
            parts.append('\\end{tabular}\n')
        if self.caption:
            parts.append('\\caption{')
            if self.label:
                parts.append('\\label{%s}' % self.label)
            parts.append(self.caption + '}\n')
        if self.longtable:
            parts.append('\\end{longtable}\n')
        if self.fontsize:
            parts.append('}\n')
        if floating:
            parts.append('\\end{table}\n')

        return ''.join(parts)

    def getColSeparatorForLatex(self, colIndex):
        """
        Returns '|' if the col at 'colIndex' is followed by a separator,
        the empty string otherwise.
        """
        if self.withColSeparators or colIndex in self.colSeparators:
            return '|'
        return ''

    def formatHeaderForLatex(self):
        """
        Returns the headers as one LaTeX table row, or the empty string if
        the table has no headers.
        """
        if not self.headers:
            return ''
        colsep = ' & '
        linesep = ' \\\\\n'
        cells = [self.formatCellContentForLatex2(header, i)[0]
                 for (i, header) in enumerate(self.headers)]
        return colsep.join(cells) + linesep

    def formatCellContentForLatex(self, rowIndex, colIndex):
        """
        Returns the tuple (latex, colspan) for the cell at ('rowIndex',
        'colIndex').
        """
        content = self.cols[colIndex][rowIndex]
        colspan = self.getColspan(rowIndex, colIndex)
        return self.formatCellContentForLatex2(content, colIndex, colspan)

    def formatCellContentForLatex2(self, content, colIndex, colspan=1):
        """
        Returns the tuple (latex, colspan) for a cell with the given
        'content' that starts at col 'colIndex' and occupies 'colspan' cols.

        Numbers are formatted with 'formatNumber' and are only wrapped in
        a multicolumn if they occupy more than one col. All other content
        is always wrapped in a multicolumn that uses the text format of the
        col and the separator of the last col the cell occupies.
        """
        if self._isNumber(content):
            s = str(formatNumber(content, ndec=self.ndec))
            if colspan > 1:
                s = '\\multicolumn{%d}{%s}{%s}' % \
                    (colspan, self.defaultColFormat, s)
        else:
            textFormat = self.colTextFormats.get(colIndex, 'c') + \
                         self.getColSeparatorForLatex(colIndex+colspan-1)
            s = '\\multicolumn{%d}{%s}{%s}' % (colspan, textFormat, content)
        return (s, colspan)

    def formatBodyForLatex(self):
        """
        Returns the rows of the table as LaTeX table rows. The last row is
        only terminated if it is followed by a row separator.
        """
        colsep = ' & '
        linesep = ' \\\\\n'
        n = len(self.cols)
        m = self._rowCount()
        parts = []
        for rowIndex in range(m):
            cells = []
            colIndex = 0
            while colIndex < n:
                content, colspan = \
                    self.formatCellContentForLatex(rowIndex, colIndex)
                cells.append(content)
                # skip the cols covered by this cell
                colIndex += max(colspan, 1)
            parts.append(colsep.join(cells))
            isLastRow = (rowIndex == m-1)
            if not isLastRow:
                parts.append(linesep)
            if rowIndex in self.rowSeparators:
                # \hline must be preceded by a terminated row
                if isLastRow:
                    parts.append(linesep)
                parts.append('\\hline\n')
            if rowIndex == self.lastHeaderRow and self.longtable:
                parts.append('\\endhead\n')
        return ''.join(parts)

class ColTable(AbstractTable):
    """
    A table that is built col by col, see 'addCol'.
    """

    def __init__(self, caption=None):
        """
        Creates a new 'ColTable' with an optional 'caption'.
        A 'ColTable' is a table that is constructed by adding cols to the table,
        i.e. the table is built col by col.
        """
        AbstractTable.__init__(self, caption)

    def _resolveColIndex(self, index, argName):
        """
        Returns 'index' if it is not negative, otherwise the index of the
        most recently added col. 'argName' is the name of the argument
        that is reported if no col has been added yet, in which case a
        'ValueError' is raised.
        """
        if index >= 0:
            return index
        if not self.cols:
            raise ValueError('Illegal state: no cols have been added, '
                             'no %s is given.' % argName)
        return len(self.cols)-1

    def addCol(self, col, header=''):
        """
        Adds a new 'col' to the table. 'header' is the header of the 'col'.
        The first col of the table determines the length of all other cols.
        If the new col is not the first col and its length differs from
        the first col's length, a 'TableError' is raised.
        """
        if self.cols and len(col) != len(self.cols[0]):
            msg = ('Illegal length for new column.\n'
                   'New column has length %d, but reference '
                   'column has length %d.\n' % (len(col), len(self.cols[0])))
            # '%s' also copes with headers that are not strings
            msg += '%s: %s' % (header, str(col))
            raise TableError(msg)
        self.cols.append(col)
        self.headers.append(header)

    def addRowSeparator(self, afterRow):
        """
        Adds a separator after row 'afterRow'. The first row has index 0.
        """
        self.rowSeparators.append(afterRow)

    def addColSeparator(self, afterCol=-1):
        """
        Adds a separator after col 'afterCol'.
        The first col has index 0.
        If 'afterCol' is negative, a separator is inserted after the most
        recently added col. A 'ValueError' is raised if 'afterCol' is
        negative and no col has been added yet.
        """
        self.colSeparators.append(self._resolveColIndex(afterCol, 'afterCol'))

    def setColFormat(self, type, colNo=-1):
        """
        Sets the (latex) format for column at index colNo (first col has
        index 0). If 'colNo' is negative, the format for the most
        recently added col is set. The format only applies to cells
        with numerical content. Use 'setColTextFormat' for setting the
        format of cells with textual content. A 'ValueError' is raised if
        'colNo' is negative and no col has been added yet.
        """
        self.colFormats[self._resolveColIndex(colNo, 'colNo')] = type

    def setColTextFormat(self, type, colNo=-1):
        """
        Sets the (latex) format for column at index colNo (first col has
        index 0). If 'colNo' is negative, the format for the most
        recently added col is set. The format only applies to cells
        with textual content. Use 'setColFormat' for setting the
        format of cells with numerical content. A 'ValueError' is raised if
        'colNo' is negative and no col has been added yet.
        """
        self.colTextFormats[self._resolveColIndex(colNo, 'colNo')] = type


class RowTable(AbstractTable):
    """
    A table that is built row by row, see 'addRow'.
    """

    def __init__(self, caption=None):
        """
        Creates a new 'RowTable' with an optional 'caption'.
        A 'RowTable' is a table that is constructed by adding rows to the
        table, i.e. the table is built row by row.
        """
        AbstractTable.__init__(self, caption)

    def _resolveRowIndex(self, index, argName):
        """
        Returns 'index' if it is not negative, otherwise the index of the
        most recently added row. 'argName' is the name of the argument
        that is reported if no row has been added yet, in which case a
        'ValueError' is raised.
        """
        if index >= 0:
            return index
        if not self.cols:
            raise ValueError('Illegal state: no rows have been added, '
                             'no %s is given.' % argName)
        return len(self.cols[0])-1

    def _resolveColIndex(self, colNo):
        """
        Returns 'colNo' if it is not negative, otherwise the index of the
        last col. A 'ValueError' is raised if 'colNo' is negative and the
        table has no cols yet.
        """
        if colNo >= 0:
            return colNo
        if not self.cols:
            raise ValueError('Illegal state: no cols have been added, '
                             'no colNo is given.')
        return len(self.cols)-1

    def addRow(self, row):
        """
        'row' is a list that contains the values for one row. If
        there are too many values in the list (i.e. a row with
        less cells than this has been added before) a 'TableError' is
        raised. If there are not enough values the rest is filled with
        the empty string.

        A value can occupy several cells of a row. You can specify the
        number of cells a value occupies by inserting the tuple
        (colspan, value) into the row-list. 'colspan' is the number
        of cells 'value' occupies; a 'ValueError' is raised if it is
        less than 1.

        Example::

            addRow(['first', (2, 'hello world!'), 'fourth'])

        inserts a row that is 4 cells wide into the table. The first
        cell contains the value 'first', the second and third cell
        together contain the value 'hello world!' and the fourth cell
        contains the value 'fourth'.

        The table is left unchanged if the row is rejected.
        """
        rowIndex = 0
        if self.cols:
            rowIndex = len(self.cols[0])

        # Replace the (colspan, value) tuples with the value followed by
        # dummy values for the cells that are covered by it. Example:
        # ['first', (2, 'hello world!'), 'fourth'] -->
        # ['first', 'hello world!', None, 'fourth']
        cells = []
        spans = []
        for c in row:
            if isinstance(c, tuple):
                colspan = c[0]
                content = c[1]
                if colspan < 1:
                    raise ValueError('Illegal colspan %r, must be at '
                                     'least 1' % (colspan,))
                spans.append((len(cells), colspan))
                cells.append(content)
                cells.extend((colspan-1)*[None])
            else:
                cells.append(c)

        if self.cols:
            refLen = len(self.cols)
            n = len(cells)
            if n > refLen:
                raise TableError('Row is too long. First row has length '
                                 '%d, this row has length %d' % (refLen, n))
            cells.extend((refLen-n)*[''])

        # The row is accepted: only now touch the state of the table.
        for (colIndex, colspan) in spans:
            self.setColspan(rowIndex, colIndex, colspan)
        if self.cols:
            addRowToCols(cells, self.cols)
        else:
            for cell in cells:
                self.cols.append([cell])

    def addRowSeparator(self, afterRow=-1):
        """
        Adds a separator after row 'afterRow'. The first row has index 0.
        If 'afterRow' is negative, a separator is inserted after the most
        recently added row. A 'ValueError' is raised if 'afterRow' is
        negative and no row has been added yet.
        """
        self.rowSeparators.append(self._resolveRowIndex(afterRow, 'afterRow'))

    def addColSeparator(self, afterCol):
        """
        Adds a separator after col 'afterCol'. The first col has index 0.
        """
        self.colSeparators.append(afterCol)

    def setColFormat(self, type, colNo=-1):
        """
        Sets the (latex) format for column at index colNo (first col
        has index 0). If 'colNo' is negative, the format for the last col
        is set; a 'ValueError' is raised in that case if the table has no
        cols yet. The format only applies to cells with numerical
        content. Use 'setColTextFormat' for setting the format of
        cells with textual content.
        """
        self.colFormats[self._resolveColIndex(colNo)] = type

    def setColTextFormat(self, type, colNo=-1):
        """
        Sets the (latex) format for column at index colNo (first col
        has index 0). If 'colNo' is negative, the format for the last col
        is set; a 'ValueError' is raised in that case if the table has no
        cols yet. The format only applies to cells with textual
        content. Use 'setColFormat' for setting the format of
        cells with numerical content.
        """
        self.colTextFormats[self._resolveColIndex(colNo)] = type

    def setLastHeaderRow(self, index=-1):
        """
        Marks the 'index'-th row as the last header row. If 'index < 0',
        the most recently added row is marked as the last header row
        (in case no row has been added, a 'ValueError' is raised).
        """
        self.lastHeaderRow = self._resolveRowIndex(index, 'index')




def _setupTestTable():
    """
    Builds the 'RowTable' used by the self test of this module: a header
    row that is marked as the last header row, followed by four rows that
    use different colspans.
    """
    bodyRows = [
        [1,2,3],
        [(2, '12'), 3],
        [1, (2, '23')],
        [(3, '123')],
    ]
    table = RowTable()
    table.addRow(['Col 1', 'Col 2', 'Col 3'])
    table.setLastHeaderRow()
    for bodyRow in bodyRows:
        table.addRow(bodyRow)
    return table

if __name__ == '__main__':
    # Self test: print the test table in CSV format. The call form of
    # print with a single argument works as a statement and as a function.
    t = _setupTestTable()
    print(t.getAsCsv())
        

