import argparse
import asyncio
import gc
import os.path
import pathlib
import socket
import ssl
import sys


# Module-wide switch for connection logging in the echo handlers below.
# The __main__ block sets it to 1 for --print, but forces it back to 0 when
# the loop has print_debug_info (presumably because print_debug clears the
# screen every 0.5 s and would clobber the log output).
PRINT = 0


async def echo_server(loop, address, unix):
    """Accept connections on *address* and echo each one via loop.sock_*.

    :param loop: event loop used for sock_accept and to spawn client tasks.
    :param address: filesystem path when *unix* is true, otherwise a
        ``(host, port)`` tuple.
    :param unix: bind an AF_UNIX stream socket instead of a TCP one.

    Runs until cancelled; the listening socket is closed on exit.
    """
    if unix:
        family = socket.AF_UNIX
    elif ':' in address[0]:
        # An IPv6 literal host (e.g. '::1') needs an AF_INET6 socket.
        family = socket.AF_INET6
    else:
        family = socket.AF_INET

    # Strong references to in-flight client tasks: the event loop only keeps
    # weak references, so an otherwise unreferenced task could be garbage
    # collected before it finishes.
    clients = set()

    # Enter the context manager right away so the socket is closed even if
    # bind()/listen() raise (e.g. address already in use).
    with socket.socket(family, socket.SOCK_STREAM) as sock:
        if not unix:
            sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        sock.bind(address)
        sock.listen(5)
        sock.setblocking(False)
        if PRINT:
            print('Server listening at', address)
        while True:
            client, addr = await loop.sock_accept(sock)
            if PRINT:
                print('Connection from', addr)
            task = loop.create_task(echo_client(loop, client))
            clients.add(task)
            task.add_done_callback(clients.discard)


async def echo_client(loop, client):
    """Echo everything received on *client* back to it until EOF.

    :param loop: event loop whose sock_recv/sock_sendall drive the I/O.
    :param client: connected socket (as returned by loop.sock_accept);
        it is closed when this coroutine returns.
    """
    try:
        client.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    except (OSError, AttributeError):
        # OSError: not a TCP socket (e.g. AF_UNIX).
        # AttributeError: this platform's socket module lacks the constant.
        pass

    with client:
        try:
            while True:
                data = await loop.sock_recv(client, 1000000)
                if not data:
                    break
                await loop.sock_sendall(client, data)
        except ConnectionError:
            # A reset/aborted peer ends the session just like EOF does,
            # instead of surfacing as an unretrieved task exception.
            pass
    if PRINT:
        print('Connection closed')


async def echo_client_streams(reader, writer):
    """Echo data from *reader* back through *writer* until EOF.

    Used as the client_connected_cb passed to asyncio.start_server /
    asyncio.start_unix_server. The writer is always closed on exit.
    """
    sock = writer.get_extra_info('socket')
    try:
        sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
    except (OSError, AttributeError):
        # OSError: not a TCP socket (e.g. AF_UNIX).
        # AttributeError: no socket object, or the constant is unavailable.
        pass
    if PRINT:
        # 'peername' works even when no raw socket object is exposed.
        print('Connection from', writer.get_extra_info('peername'))
    try:
        while True:
            data = await reader.read(1000000)
            if not data:
                break
            writer.write(data)
            # Flow control: wait while the transport's write buffer is above
            # its high-water mark instead of letting it grow without bound.
            await writer.drain()
    except ConnectionError:
        # Peer reset/abort: treat like EOF.
        pass
    finally:
        writer.close()
    if PRINT:
        print('Connection closed')


class EchoProtocol(asyncio.Protocol):
    """Callback-style protocol that writes every received chunk straight back."""

    def connection_made(self, transport):
        # Keep the transport so data_received can reply on it.
        self.transport = transport

    def data_received(self, data):
        out = self.transport
        out.write(data)

    def connection_lost(self, exc):
        # Drop the reference to the closed transport.
        self.transport = None


class EchoBufferedProtocol(asyncio.BufferedProtocol):
    """Echo protocol whose incoming data lands in one reusable buffer."""

    def connection_made(self, transport):
        self.transport = transport
        # A single 256 KiB receive buffer, handed out by get_buffer every time.
        self.buffer = bytearray(256 * 1024)

    def get_buffer(self, sizehint):
        # sizehint is ignored; the whole preallocated buffer is offered.
        return self.buffer

    def buffer_updated(self, nbytes):
        # Slicing a bytearray copies it. The copy is what gets written, so
        # outgoing data is not overwritten when the next read reuses
        # self.buffer.
        received = self.buffer[:nbytes]
        self.transport.write(received)

    def connection_lost(self, exc):
        self.transport = None


async def print_debug(loop):
    """Redraw *loop*'s debug info every half second, forever.

    Only scheduled by the __main__ block for loops that expose
    print_debug_info().
    """
    clear_screen = '\x1b[2J'  # ANSI "erase display"
    while True:
        print(clear_screen)
        loop.print_debug_info()
        await asyncio.sleep(0.5)


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('--uvloop', default=False, action='store_true')
    parser.add_argument('--streams', default=False, action='store_true')
    parser.add_argument('--proto', default=False, action='store_true')
    parser.add_argument('--addr', default='127.0.0.1:25000', type=str)
    parser.add_argument('--print', default=False, action='store_true')
    parser.add_argument('--ssl', default=False, action='store_true')
    parser.add_argument('--buffered', default=False, action='store_true')
    args = parser.parse_args()

    if args.uvloop:
        import uvloop
        loop = uvloop.new_event_loop()
        print('using UVLoop')
    else:
        loop = asyncio.new_event_loop()
        print('using asyncio loop')

    asyncio.set_event_loop(loop)
    loop.set_debug(False)

    if args.print:
        PRINT = 1

    if hasattr(loop, 'print_debug_info'):
        loop.create_task(print_debug(loop))
        PRINT = 0

    unix = False
    if args.addr.startswith('file:'):
        unix = True
        addr = args.addr[5:]
        if os.path.exists(addr):
            os.remove(addr)
    else:
        addr = args.addr.split(':')
        addr[1] = int(addr[1])
        addr = tuple(addr)

    print('serving on: {}'.format(addr))

    server_context = None
    if args.ssl:
        print('with SSL')
        if hasattr(ssl, 'PROTOCOL_TLS'):
            server_context = ssl.SSLContext(ssl.PROTOCOL_TLS)
        else:
            server_context = ssl.SSLContext(ssl.PROTOCOL_SSLv23)
        server_context.load_cert_chain(
            (pathlib.Path(__file__).parent.parent.parent /
                'tests' / 'certs' / 'ssl_cert.pem'),
            (pathlib.Path(__file__).parent.parent.parent /
                'tests' / 'certs' / 'ssl_key.pem'))
        if hasattr(server_context, 'check_hostname'):
            server_context.check_hostname = False
        server_context.verify_mode = ssl.CERT_NONE

    if args.streams:
        if args.proto:
            print('cannot use --stream and --proto simultaneously')
            exit(1)

        if args.buffered:
            print('cannot use --stream and --buffered simultaneously')
            exit(1)

        print('using asyncio/streams')
        if unix:
            coro = asyncio.start_unix_server(echo_client_streams,
                                             addr,
                                             ssl=server_context)
        else:
            coro = asyncio.start_server(echo_client_streams,
                                        *addr,
                                        ssl=server_context)
        srv = loop.run_until_complete(coro)
    elif args.proto:
        if args.streams:
            print('cannot use --stream and --proto simultaneously')
            exit(1)

        if args.buffered:
            print('using buffered protocol')
            protocol = EchoBufferedProtocol
        else:
            print('using simple protocol')
            protocol = EchoProtocol

        if unix:
            coro = loop.create_unix_server(protocol, addr,
                                           ssl=server_context)
        else:
            coro = loop.create_server(protocol, *addr,
                                      ssl=server_context)
        srv = loop.run_until_complete(coro)
    else:
        if args.ssl:
            print('cannot use SSL for loop.sock_* methods')
            exit(1)

        print('using sock_recv/sock_sendall')
        loop.create_task(echo_server(loop, addr, unix))
    try:
        loop.run_forever()
    finally:
        if hasattr(loop, 'print_debug_info'):
            gc.collect()
            print(chr(27) + "[2J")
            loop.print_debug_info()

        loop.close()
