#!/usr/bin/env python

import gzip
import StringIO
import Queue
import threading
import urlparse
import sys
import argparse
import logging
import os

import boto
import boto.s3.connection

# Route INFO-and-above log records to stderr (configured on the root logger at
# import time) so stdout stays reserved for the downloaded object data.
logging.basicConfig(level=logging.INFO, stream=sys.stderr)

# Module logger, named after this script's file name.
logger = logging.getLogger(os.path.basename(__file__))


def read_key(key, num_attempts, start=None, end=None):
    """Download an S3 key's contents, retrying on failure.

    Keys whose name ends in ``.gz`` are gunzipped after download.

    :param key: boto S3 key (anything exposing ``name`` and
        ``get_contents_as_string(headers=...)``).
    :param num_attempts: maximum number of download attempts.
    :param start: first byte of an optional HTTP byte range (inclusive).
    :param end: last byte of an optional HTTP byte range (inclusive). The
        Range header is only sent when both ``start`` and ``end`` are given.
    :returns: the (possibly decompressed) contents; ``''`` if the only
        successful responses were empty; ``None`` if no attempt succeeded,
        in which case an error is logged.
    """
    logger.info('Reading file %s...', key.name)

    # The Range header is the same for every attempt, so build it once.
    if start is not None and end is not None:
        headers = {'Range': 'bytes=%s-%s' % (start, end)}
    else:
        headers = None

    data = None
    last_exc = None
    for _ in xrange(num_attempts):
        try:
            payload = key.get_contents_as_string(headers=headers)
            if payload and key.name.endswith('.gz'):
                with gzip.GzipFile(fileobj=StringIO.StringIO(payload)) as gz:
                    payload = gz.read()
        except KeyboardInterrupt:
            sys.exit(1)
        except Exception as exc:
            # `data` is only updated after a fully successful attempt, so a
            # failed decompression can never leak raw compressed bytes.
            last_exc = exc
            continue

        data = payload
        if data:
            break
        # An empty body is retried; if every attempt is empty, '' is returned.

    if data is None:
        logger.error('Failed %s times. Last exception: %s', num_attempts, last_exc)
    return data


class MultiKeyProducer(threading.Thread):
    """Worker thread that downloads keys taken from a shared queue.

    Each successfully downloaded payload is put on ``messages``. When the key
    queue is drained (or the worker fails unexpectedly) it puts exactly one
    ``None`` sentinel on ``messages`` and exits; the consumer counts these
    sentinels to know when every worker is done.
    """

    def __init__(self, keys, messages, num_attempts):
        super(MultiKeyProducer, self).__init__()
        self.keys = keys                  # Queue.Queue of S3 keys to download
        self.messages = messages          # Queue.Queue receiving payloads and the final sentinel
        self.num_attempts = num_attempts  # retry budget passed to read_key

    def run(self):
        try:
            while True:
                # get_nowait() is atomic. Checking empty() and then calling a
                # blocking get() races with sibling workers: if another thread
                # takes the last key in between, this one blocks forever.
                try:
                    key = self.keys.get_nowait()
                except Queue.Empty:
                    return
                try:
                    data = read_key(key, self.num_attempts)
                finally:
                    self.keys.task_done()
                # None is the end-of-stream sentinel, so a failed download
                # (already logged by read_key) must not be forwarded.
                if data is not None:
                    self.messages.put(data)
        finally:
            # Always signal completion, even if read_key raised, so the
            # consumer never waits on a worker that has died.
            self.messages.put(None)


def get_files(s3_conn, url_raw, num_attempts, total_threads):
    """Yield the contents of every key under an S3 prefix, downloaded in parallel.

    :param s3_conn: boto S3 connection.
    :param url_raw: URL of the form ``s3://bucket/prefix``.
    :param num_attempts: retry budget for each key.
    :param total_threads: maximum number of concurrent download threads.
    :returns: generator of payloads in completion order (not listing order).
    """
    url = urlparse.urlparse(url_raw)
    bucket = s3_conn.get_bucket(url.netloc)
    bucket_keys = list(bucket.list(url.path.strip('/')))
    if not bucket_keys:
        logger.error('The path %s does not contain any keys', url_raw)
        return

    keys = Queue.Queue()
    for k in bucket_keys:
        keys.put(k)

    # There is no point starting more workers than keys. At least one worker is
    # required, otherwise the sentinel-counting loop below would never finish.
    num_workers = max(1, min(total_threads, len(bucket_keys)))
    # Bounded so workers cannot buffer more payloads than the consumer drains.
    messages = Queue.Queue(maxsize=num_workers)

    producers = [MultiKeyProducer(keys, messages, num_attempts) for _ in xrange(num_workers)]
    for producer in producers:
        # Daemon threads: if the consumer stops early (e.g. a broken stdout
        # pipe), workers blocked on the full queue must not keep the process alive.
        producer.daemon = True
        producer.start()

    # Each worker ends with exactly one None sentinel; stop after all of them.
    finished = 0
    while finished < num_workers:
        data = messages.get()
        if data is None:
            finished += 1
        else:
            yield data


def get_file(s3_conn, url_raw, num_attempts, start=None, end=None):
    """Read a single S3 key, optionally restricted to an inclusive byte range.

    :param s3_conn: boto S3 connection.
    :param url_raw: URL of the form ``s3://bucket/key``.
    :param num_attempts: retry budget passed to read_key.
    :param start: first byte of the range (inclusive), or None.
    :param end: last byte of the range (inclusive), or None.
    :returns: a one-element list holding the payload; ``['']`` when the key is
        missing or could not be read (both cases are logged).
    """
    url = urlparse.urlparse(url_raw)
    bucket = s3_conn.get_bucket(url.netloc)
    key = bucket.get_key(url.path.strip('/'))
    if key is None:
        logger.error('The key %s does not exist', url_raw)
        return ['']
    data = read_key(key, num_attempts, start=start, end=end)
    # read_key returns None (after logging) when every attempt failed. Callers
    # write the items straight to stdout, so substitute an empty string.
    return [data if data is not None else '']


if __name__ == '__main__':
    parser = argparse.ArgumentParser()

    parser.add_argument('-r', '--retries', type=int, help='Number of attempts', default=10)
    parser.add_argument('-w', '--workers', type=int, help='Number of concurrent downloads', default=20)
    parser.add_argument('--offset', type=int, help='Bytes to read from when a single key specified', default=None)
    parser.add_argument('--bytes', type=int, help='Bytes to read when a single key specified', default=None)
    parser.add_argument('url', type=str, help='S3 URL (s3://bucket/key or s3://bucket/prefix)')

    args = parser.parse_args()

    # Validate arguments before opening a connection.
    if args.offset is not None and args.bytes is None:
        logger.error('The argument --offset should be only specified with --bytes')
        sys.exit(1)

    if args.bytes is not None:
        offset = args.offset or 0
        if offset < 0 or args.bytes < 1:
            logger.error('The argument --offset must be >= 0 and --bytes must be >= 1')
            sys.exit(1)
    elif args.workers < 1:
        # With zero workers nobody would ever signal completion.
        logger.error('The argument --workers must be >= 1')
        sys.exit(1)

    logger.info('Opening %s...', args.url)
    s3_conn = boto.connect_s3()

    if args.bytes is not None:
        # HTTP byte ranges are zero-based and inclusive, so the `bytes` bytes
        # starting at `offset` are bytes=offset-(offset + bytes - 1).
        log_lines = get_file(
            s3_conn=s3_conn,
            url_raw=args.url,
            num_attempts=args.retries,
            start=offset,
            end=offset + args.bytes - 1)
    else:
        log_lines = get_files(
            s3_conn=s3_conn,
            url_raw=args.url,
            num_attempts=args.retries,
            total_threads=args.workers)

    for log_line in log_lines:
        sys.stdout.write(log_line)

    sys.exit(0)

