# Source code for codetransformer.transformers.literals

from collections import OrderedDict
from decimal import Decimal
from itertools import islice
import sys
from textwrap import dedent

from .. import instructions
from ..core import CodeTransformer
from ..patterns import pattern,  matchany, var
from ..utils.instance import instance


# Start code used by the transformers in this module while they rewrite a
# comprehension body: it is entered with ``self.begin(IN_COMPREHENSION)`` and
# patterns that should only fire in that state are registered with
# ``startcodes=(IN_COMPREHENSION,)``.
IN_COMPREHENSION = 'in_comprehension'


class overloaded_dicts(CodeTransformer):
    """Transformer that allows us to overload dictionary literals.

    This acts by creating an empty map and then inserting every
    key value pair in order.

    The code that is generated will turn something like::

        {k_0: v_0, k_1: v_1, ..., k_n: v_n}

    into::

        _tmp = astype()
        _tmp[k_0] = v_0
        _tmp[k_1] = v_1
        ...
        _tmp[k_n] = v_n
        _tmp  # leaves the map on the stack.

    Parameters
    ----------
    astype : callable
        The constructor for the type to create.

    Notes
    -----
    Which ``BUILD_MAP`` handler is defined depends on ``sys.version_info`` at
    class-creation time: one pair of handlers (``BUILD_MAP`` / ``STORE_MAP``)
    for Python 3.4 and earlier, a tuple-based handler for 3.5 and later, and
    an additional ``BUILD_CONST_KEY_MAP`` handler for 3.6 and later.

    Examples
    --------
    >>> from collections import OrderedDict
    >>> ordereddict_literals = overloaded_dicts(OrderedDict)
    >>> @ordereddict_literals
    ... def f():
    ...     return {'a': 1, 'b': 2, 'c': 3}
    ...
    >>> f()
    OrderedDict([('a', 1), ('b', 2), ('c', 3)])
    """
    def __init__(self, astype):
        super().__init__()
        # Constructor called (with no arguments) to create each new mapping.
        self.astype = astype

    @pattern(instructions.BUILD_MAP, matchany[var], instructions.MAP_ADD)
    def _start_comprehension(self, instr, *instrs):
        """Rewrite a ``BUILD_MAP ... MAP_ADD`` sequence (a dict comprehension).

        The mapping is created with ``self.astype()`` and kept in the local
        ``__map__``; the matched ``MAP_ADD`` is replaced with a subscript
        store into that local. Afterwards the transformer switches to the
        ``IN_COMPREHENSION`` start code so that ``_return_value`` can fire.
        """
        yield instructions.LOAD_CONST(self.astype).steal(instr)
        # TOS  = self.astype

        yield instructions.CALL_FUNCTION(0)
        # TOS  = m = self.astype()

        # Stash the new mapping in a local so it can be reloaded at each
        # insertion and at the final return.
        yield instructions.STORE_FAST('__map__')

        # Everything between the BUILD_MAP and the trailing MAP_ADD is passed
        # back through the pattern dispatcher so nested patterns still apply.
        *body, map_add = instrs
        yield from self.patterndispatcher(body)
        # TOS  = k
        # TOS1 = v

        yield instructions.LOAD_FAST('__map__').steal(map_add)
        # TOS  = __map__
        # TOS1 = k
        # TOS2 = v

        yield instructions.ROT_TWO()
        # TOS  = k
        # TOS1 = __map__
        # TOS2 = v

        # STORE_SUBSCR performs TOS1[TOS] = TOS2, i.e. __map__[k] = v.
        yield instructions.STORE_SUBSCR()
        self.begin(IN_COMPREHENSION)

    @pattern(instructions.RETURN_VALUE, startcodes=(IN_COMPREHENSION,))
    def _return_value(self, instr):
        """While in ``IN_COMPREHENSION``, return ``__map__`` from the code.

        A ``LOAD_FAST('__map__')`` is emitted immediately before the original
        ``RETURN_VALUE`` so that the overloaded mapping is what gets returned.
        """
        yield instructions.LOAD_FAST('__map__').steal(instr)
        # TOS  = __map__

        yield instr

    if sys.version_info[:2] <= (3, 4):
        # Python 3.4

        @pattern(instructions.BUILD_MAP)
        def _build_map(self, instr):
            """Replace ``BUILD_MAP`` with a call to ``self.astype()``.

            The new mapping is duplicated ``instr.arg`` times so that each
            rewritten ``STORE_MAP`` (see ``_store_map``) can consume one copy
            and one copy is still left on the stack at the end.
            """
            yield instructions.LOAD_CONST(self.astype).steal(instr)
            # TOS  = self.astype

            yield instructions.CALL_FUNCTION(0)
            # TOS  = m = self.astype()

            yield from (instructions.DUP_TOP(),) * instr.arg
            # TOS  = m
            # ...
            # TOS[instr.arg] = m

        @pattern(instructions.STORE_MAP)
        def _store_map(self, instr):
            """Replace ``STORE_MAP`` with an explicit ``m[k] = v`` store.

            The rotations reorder the stack into the layout expected by
            ``STORE_SUBSCR`` (key on top, container below it, value third).
            """
            # TOS  = k
            # TOS1 = v
            # TOS2 = m
            # TOS3 = m

            yield instructions.ROT_THREE().steal(instr)
            # TOS  = v
            # TOS1 = m
            # TOS2 = k
            # TOS3 = m

            yield instructions.ROT_THREE()
            # TOS  = m
            # TOS1 = k
            # TOS2 = v
            # TOS3 = m

            yield instructions.ROT_TWO()
            # TOS  = k
            # TOS1 = m
            # TOS2 = v
            # TOS3 = m

            # Consumes k, one copy of m and v; performs m[k] = v.
            yield instructions.STORE_SUBSCR()
            # TOS  = m

    else:
        # Python 3.5 and beyond!

        def _construct_map(self, key_value_pairs):
            """Build ``self.astype()`` from a flat ``(k0, v0, k1, v1, ...)``.

            Even positions are treated as keys and odd positions as values;
            the pairs are inserted in order.
            """
            mapping = self.astype()
            for key, value in zip(key_value_pairs[::2], key_value_pairs[1::2]):
                mapping[key] = value
            return mapping

        @pattern(instructions.BUILD_MAP)
        def _build_map(self, instr):
            """Replace ``BUILD_MAP`` with a call to ``self._construct_map``.

            The keys and values already on the stack are packed into a single
            tuple which is then passed to ``_construct_map``.
            """
            # TOS      = vn
            # TOS1     = kn
            # ...
            # TOSN     = v0
            # TOSN + 1 = k0
            # Construct a tuple of (k0, v0, k1, v1, ..., kn, vn) for
            # each of the key: value pairs in the dictionary.
            yield instructions.BUILD_TUPLE(instr.arg * 2).steal(instr)
            # TOS  = (k0, v0, k1, v1, ..., kn, vn)

            yield instructions.LOAD_CONST(self._construct_map)
            # TOS  = self._construct_map
            # TOS1 = (k0, v0, k1, v1, ..., kn, vn)

            # Swap so the callable sits below its single argument.
            yield instructions.ROT_TWO()
            # TOS  = (k0, v0, k1, v1, ..., kn, vn)
            # TOS1 = self._construct_map

            yield instructions.CALL_FUNCTION(1)
            # TOS  = self._construct_map((k0, v0, k1, v1, ..., kn, vn))

    if sys.version_info >= (3, 6):
        def _construct_const_map(self, values, keys):
            """Build ``self.astype()`` from parallel ``values`` and ``keys``.

            The argument order (values first, then keys) mirrors the order in
            which the two tuples end up on the stack in ``_build_const_map``.
            """
            mapping = self.astype()
            for key, value in zip(keys, values):
                mapping[key] = value
            return mapping

        @pattern(instructions.LOAD_CONST, instructions.BUILD_CONST_KEY_MAP)
        def _build_const_map(self, keys, instr):
            """Replace ``LOAD_CONST keys; BUILD_CONST_KEY_MAP`` with a call to
            ``self._construct_const_map``.

            ``keys`` is the matched ``LOAD_CONST`` instruction; its ``arg`` is
            the constant tuple of keys, whose length gives the number of
            values to pack.
            """
            # The values are already on the stack; pack them before the
            # constant key tuple is loaded.
            yield instructions.BUILD_TUPLE(len(keys.arg)).steal(keys)
            # TOS  = (v0, v1, ..., vn)

            yield keys
            # TOS  = (k0, k1, ..., kn)
            # TOS1 = (v0, v1, ..., vn)

            yield instructions.LOAD_CONST(self._construct_const_map)
            # TOS  = self._construct_const_map
            # TOS1 = (k0, k1, ..., kn)
            # TOS2 = (v0, v1, ..., vn)

            # Move the callable below both of its arguments.
            yield instructions.ROT_THREE()
            # TOS  = (k0, k1, ..., kn)
            # TOS1 = (v0, v1, ..., vn)
            # TOS2 = self._construct_const_map

            yield instructions.CALL_FUNCTION(2)
            # TOS  = self._construct_const_map(values, keys)


# Ready-made transformer instance that builds dict literals with
# ``OrderedDict`` as the mapping constructor.
ordereddict_literals = overloaded_dicts(OrderedDict)


def _format_constant_docstring(type_):
    return dedent(
        """
        Transformer that applies a callable to each {type_} constant in the
        transformed code object.

        Parameters
        ----------
        xform : callable
            A callable to be applied to {type_} literals.

        See Also
        --------
        codetransformer.transformers.literals.overloaded_strs
        """
    ).format(type_=type_.__name__)


class _ConstantTransformerBase(CodeTransformer):
    """Shared base for transformers that rewrite constants of one type.

    ``transform_consts`` reads ``self._type`` to decide which constants to
    rewrite, so concrete subclasses must provide that attribute.

    Parameters
    ----------
    xform : callable
        Applied to every constant that is an instance of ``self._type``.
    """

    def __init__(self, xform):
        super().__init__()
        self.xform = xform

    def transform_consts(self, consts):
        """Apply ``self.xform`` to matching constants, recursing into
        tuples and frozensets.

        Parameters
        ----------
        consts : iterable
            The constants of the code object being transformed.

        Returns
        -------
        transformed
            Whatever the parent ``transform_consts`` returns for the tuple of
            rewritten constants.
        """
        rewritten = []
        for const in consts:
            if isinstance(const, frozenset):
                # Rebuild the frozenset from its recursively transformed
                # members.
                const = frozenset(self.transform_consts(tuple(const)))
            elif isinstance(const, tuple):
                const = self.transform_consts(const)
            elif isinstance(const, self._type):
                const = self.xform(const)
            rewritten.append(const)

        return super().transform_consts(tuple(rewritten))


def overloaded_constants(type_, __doc__=None):
    """Create a transformer class that applies a function to literals.

    The generated class is named ``overloaded_<plural>`` where ``<plural>``
    is ``type_.__name__`` pluralised: ``'es'`` is appended to names ending in
    ``'x'``, names already ending in ``'s'`` are left alone, and ``'s'`` is
    appended otherwise.

    Parameters
    ----------
    type_ : type
        The type to overload.
    __doc__ : str, optional
        Docstring for the generated transformer. When omitted, a default
        docstring is generated from ``type_``.

    Returns
    -------
    transformer : subclass of CodeTransformer
        A new code transformer class that will overload the provided
        literal types.
    """
    name = type_.__name__
    if name.endswith('x'):
        plural = name + 'es'
    elif name.endswith('s'):
        plural = name
    else:
        plural = name + 's'

    doc = _format_constant_docstring(type_) if __doc__ is None else __doc__

    namespace = {
        '_type': type_,
        '__doc__': doc,
    }
    return type("overloaded_" + plural, (_ConstantTransformerBase,), namespace)


# Transformer classes for each supported constant type. Each is instantiated
# with the callable that should be applied to constants of that type.
overloaded_strs = overloaded_constants(
    str,
    __doc__=dedent(
        """
        A transformer that overloads string literals.

        Rewrites all constants of the form::

            "some string"

        as::

            xform("some string")

        Parameters
        ----------
        xform : callable
            Function to call on all string literals in the transformer target.

        Examples
        --------
        >>> @overloaded_strs(lambda x: "ayy lmao ")
        ... def prepend_foo(s):
        ...     return "foo" + s
        ...
        >>> prepend_foo("bar")
        'ayy lmao bar'
        """
    )
)
# The remaining classes use the generated default docstring.
overloaded_bytes = overloaded_constants(bytes)
overloaded_floats = overloaded_constants(float)
overloaded_ints = overloaded_constants(int)
overloaded_complexes = overloaded_constants(complex)

# Ready-made transformer instances:
# ``tuple`` is applied to every str constant.
haskell_strs = overloaded_strs(tuple)
# ``bytearray`` is applied to every bytes constant.
bytearray_literals = overloaded_bytes(bytearray)
# ``Decimal`` is applied to every float constant.
decimal_literals = overloaded_floats(Decimal)


def _start_comprehension(self, *instrs):
    """Enter the ``IN_COMPREHENSION`` state and re-dispatch the matched
    instructions.

    Installed by ``overloaded_build`` as a pattern method on the generated
    transformer class when an ``add_name`` is given.
    """
    # Switch state first so patterns restricted to IN_COMPREHENSION are
    # active while the matched instructions are dispatched.
    self.begin(IN_COMPREHENSION)
    yield from self.patterndispatcher(instrs)


def _return_value(self, instr):
    """Wrap the value being returned in a call to ``self.xform``.

    Installed by ``overloaded_build`` as a pattern on ``RETURN_VALUE`` that is
    only active in the ``IN_COMPREHENSION`` state. The call instructions are
    emitted immediately before the original ``RETURN_VALUE``.
    """
    # TOS  = collection

    yield instructions.LOAD_CONST(self.xform).steal(instr)
    # TOS  = self.xform
    # TOS1 = collection

    # Swap so the callable sits below its single argument.
    yield instructions.ROT_TWO()
    # TOS  = collection
    # TOS1 = self.xform

    yield instructions.CALL_FUNCTION(1)
    # TOS  = self.xform(collection)

    yield instr


# Added as a method for overloaded_build
def _build(self, instr):
    """Keep the matched build instruction and pass its result to
    ``self.xform``.

    Installed by ``overloaded_build`` as a pattern on the ``BUILD_*``
    instruction of the overloaded type.
    """
    yield instr
    # TOS  = built collection

    yield instructions.LOAD_CONST(self.xform)
    # TOS  = self.xform
    # TOS1 = built collection

    # Swap so the callable sits below its single argument.
    yield instructions.ROT_TWO()
    # TOS  = built collection
    # TOS1 = self.xform

    yield instructions.CALL_FUNCTION(1)
    # TOS  = self.xform(built collection)


def overloaded_build(type_, add_name=None):
    """Factory for constant transformers that apply to a given
    build instruction.

    Parameters
    ----------
    type_ : type
        The object type to overload the construction of. This must be one of
        "buildable" types, or types with a "BUILD_*" instruction.
    add_name : str, optional
        The suffix of the instruction that adds elements to the collection.
        For example: 'add' or 'append'

    Returns
    -------
    transformer : subclass of CodeTransformer
        A new code transformer class that will overload the provided
        literal types.

    Raises
    ------
    TypeError
        If ``instructions`` has no ``BUILD_<TYPENAME>`` instruction for
        ``type_``, or if ``add_name`` is given and ``instructions`` has no
        ``<TYPENAME>_<ADD_NAME>`` instruction.
    """
    typename = type_.__name__
    instrname = 'BUILD_' + typename.upper()
    dict_ = OrderedDict(
        __doc__=dedent(
            """
            A CodeTransformer for overloading {name} instructions.
            """.format(name=instrname)
        )
    )

    try:
        build_instr = getattr(instructions, instrname)
    except AttributeError:
        raise TypeError("type %s is not buildable" % typename)

    if add_name is not None:
        try:
            add_instr = getattr(
                instructions,
                '_'.join((typename, add_name)).upper(),
            )
        except AttributeError:
            # The exception must actually be raised; merely constructing it
            # would fall through to an UnboundLocalError on ``add_instr``.
            raise TypeError("type %s is not addable" % typename)

        # Comprehensions build the collection incrementally, so the result is
        # wrapped at the RETURN_VALUE of the comprehension's code object.
        dict_['_start_comprehension'] = pattern(
            build_instr, matchany[var], add_instr,
        )(_start_comprehension)
        dict_['_return_value'] = pattern(
            instructions.RETURN_VALUE, startcodes=(IN_COMPREHENSION,),
        )(_return_value)

    dict_['_build'] = pattern(build_instr)(_build)

    # Pluralise the type name for the generated class name.
    if not typename.endswith('s'):
        typename = typename + 's'

    return type(
        'overloaded_' + typename,
        (overloaded_constants(type_),),
        dict_,
    )


# Transformer classes for the buildable collection types. No ``add_name`` is
# passed for slices, so only the build pattern is registered; lists and sets
# also register the comprehension patterns for their add instruction.
overloaded_slices = overloaded_build(slice)
overloaded_lists = overloaded_build(list, 'append')
overloaded_sets = overloaded_build(set, 'add')


# Add a special method for set overloader.
def transform_consts(self, consts):
    """Apply ``self.xform`` to frozenset constants as mutable sets.

    The constants are first run through the parent class's
    ``transform_consts``; every frozenset in the result is then replaced by
    ``self.xform(set(const))``.
    """
    base_consts = super(overloaded_sets, self).transform_consts(consts)

    converted = []
    for item in base_consts:
        if isinstance(item, frozenset):
            # Always pass a thawed set so mutations can happen inplace.
            item = self.xform(set(item))
        converted.append(item)
    return tuple(converted)


# Install the override on the set overloader, then remove the temporary
# module-level name so it is not left in the module namespace.
overloaded_sets.transform_consts = transform_consts
del transform_consts
# Ready-made transformer instance that uses ``frozenset`` as the xform.
frozenset_literals = overloaded_sets(frozenset)


# Transformer class for tuples; no ``add_name`` is passed, so only the build
# pattern is registered.
overloaded_tuples = overloaded_build(tuple)


# Add a special method for the tuple overloader.
def transform_consts(self, consts):
    """Apply ``self.xform`` to every tuple constant.

    The constants are first run through the parent class's
    ``transform_consts``; each tuple in the result is then replaced by
    ``self.xform(const)`` and everything else is passed through unchanged.
    """
    base_consts = super(overloaded_tuples, self).transform_consts(consts)

    def _convert(const):
        if isinstance(const, tuple):
            return self.xform(const)
        return const

    return tuple(map(_convert, base_consts))


# Install the override on the tuple overloader, then remove the temporary
# module-level name so it is not left in the module namespace.
overloaded_tuples.transform_consts = transform_consts
del transform_consts


@instance
class islice_literals(CodeTransformer):
    """Transformer that turns slice indexing into an islice object.

    Every ``BINARY_SUBSCR`` is routed through ``_islicer``: subscripting with
    a ``slice`` returns an ``itertools.islice`` over the subscripted object,
    while any other key falls back to ordinary ``m[k]`` indexing.

    Examples
    --------
    >>> from codetransformer.transformers.literals import islice_literals
    >>> @islice_literals
    ... def f():
    ...     return map(str, (1, 2, 3, 4))[:2]
    ...
    >>> f()
    <itertools.islice object at 0x...>
    >>> tuple(f())
    ('1', '2')
    """
    @pattern(instructions.BINARY_SUBSCR)
    def _binary_subscr(self, instr):
        """Replace ``BINARY_SUBSCR`` with a call to ``self._islicer(m, k)``."""
        yield instructions.LOAD_CONST(self._islicer).steal(instr)
        # TOS  = self._islicer
        # TOS1 = k
        # TOS2 = m

        # Move the callable below both of its arguments.
        yield instructions.ROT_THREE()
        # TOS  = k
        # TOS1 = m
        # TOS2 = self._islicer

        yield instructions.CALL_FUNCTION(2)
        # TOS  = self._islicer(m, k)

    @staticmethod
    def _islicer(m, k):
        """Return ``islice(m, ...)`` for slice keys, ``m[k]`` otherwise.

        Parameters
        ----------
        m : object
            The object being subscripted.
        k : object
            The subscript; a ``slice`` selects the lazy ``islice`` path.
        """
        if isinstance(k, slice):
            return islice(m, k.start, k.stop, k.step)

        return m[k]