#!/usr/bin/env python
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

DOCUMENTATION = """
---
module: ec2_acl
short_description: Create or delete AWS Network ACLs.
description:
  - Can create or delete AWS Network ACLs.
version_added: "1.8"
author: Edward Zarecor
options:
  state:
    description:
      - create, update or delete the acl
    required: false
    default: present
    choices: ['present', 'absent']
  name:
    description:
      - Unique name for acl
    required: true
  vpc_id:
    description:
      - The VPC that this acl belongs to
    required: true
    default: null
extends_documentation_fragment: aws
"""

EXAMPLES = '''
- ec2_acl:
    name: public-acls
    state: present
    vpc_id: 'vpc-abababab'

'''

from ansible.module_utils.basic import *
from ansible.module_utils.ec2 import *
import sys
try:
    import boto.vpc
except ImportError:
    # boto is a hard requirement. Report the failure in the key=value result
    # format and name the interpreter, so a missing boto can be traced to the
    # Python installation that actually ran the module.
    print("failed=True msg='boto required for this module (interpreter: {0})'".format(sys.executable))
    sys.exit(1)

from boto.exception import NoAuthHandlerFound

# Protocol numbers keyed by upper-case protocol name; "ALL" maps to -1.
PROTOCOL_NUMBERS = dict(ICMP=1, TCP=6, UDP=17, ALL=-1)


class DuplicateAclError(Exception):
    """Signals that a lookup expected at most one network ACL but found several."""


class ACLManager(object):
    """Create, update or delete one named network ACL inside a VPC.

    The ACL is identified by its ``Name`` tag together with the VPC id.

    connection -- object exposing the network ACL calls used below
                  (get_all_network_acls, create_network_acl, ...)
    vpc_id     -- id of the VPC the ACL belongs to
    acl_name   -- value of the ``Name`` tag identifying the ACL
    rules      -- list of rule dicts with the keys ``number``, ``type``,
                  ``protocol``, ``rule_action``, ``cidr_block``, ``from_port``
                  and ``to_port``; None means "do not manage rules"
    tags       -- optional list of ``{'key': ..., 'value': ...}`` dicts
    """

    # Rule numbers from this value upwards are reserved for AWS and are
    # never deleted by this class.
    _RESERVED_RULE_NUMBER = 32767

    def __init__(self, connection, vpc_id, acl_name, rules, tags=None):
        self.connection = connection
        self.vpc_id = vpc_id
        self.acl_name = acl_name
        self.rules = rules
        # A fresh list per instance; a shared mutable default would leak
        # state between managers.
        self.tags = tags if tags is not None else []
        self.acl = None

    def get_acl(self):
        """Return the matching ACL, or None when it does not exist yet.

        The result of a successful lookup is cached on the instance.
        Raises DuplicateAclError when more than one ACL matches.
        """
        if not self.acl:
            results = self.connection.get_all_network_acls(
                filters={"vpc_id": self.vpc_id, "tag:Name": self.acl_name})

            if len(results) > 1:
                raise DuplicateAclError("Found multiple network acls name {0} in vpc with id {1}".
                                        format(self.acl_name, self.vpc_id))
            if results:
                self.acl = results[0]

        return self.acl

    def create_acl(self):
        """Create the ACL in the VPC, tag it, and return True (changed)."""
        self.acl = self.connection.create_network_acl(self.vpc_id)
        self.do_tags()
        return True

    def update_acl(self):
        """Bring an existing ACL's rules and tags up to date.

        Returns True when at least one rule entry was created or deleted.
        Entries that already exist are replaced without being compared, and
        tags are re-applied without being compared, so neither of those is
        counted as a change.
        """
        changed = self.update_rules()
        self.do_tags()
        return changed

    def _put_entry(self, api_call, rule):
        """Send one rule dict to AWS through api_call (create or replace)."""
        api_call(
            self.acl.id,
            rule['number'],
            PROTOCOL_NUMBERS[rule['protocol'].upper()],
            rule['rule_action'],
            rule['cidr_block'],
            egress=(rule['type'] == "egress"),
            port_range_from=rule['from_port'],
            port_range_to=rule['to_port'])

    def update_rules(self):
        """Reconcile the entries of the existing ACL with ``self.rules``.

        Rules whose number is not present yet are created, rules whose number
        already exists are replaced, and existing entries that are no longer
        listed are deleted unless they fall in the reserved range. Ingress
        and egress entries are numbered independently.

        Returns True when an entry was created or deleted, False otherwise.
        When ``self.rules`` is None the entries are left untouched.
        """
        if self.rules is None:
            return False

        # Existing rule numbers per direction, keyed by the egress flag.
        # Entries report egress as the strings 'true'/'false'; the numbers
        # are normalised with int() so that "100" and 100 compare equal.
        existing = {False: [], True: []}
        for entry in self.acl.network_acl_entries:
            if entry.egress in ('true', 'false'):
                existing[entry.egress == 'true'].append(int(entry.rule_number))

        wanted = {False: set(), True: set()}
        changed = False

        for rule in self.rules:
            egress = rule['type'] == "egress"
            number = int(rule['number'])
            wanted[egress].add(number)

            if number in existing[egress]:
                # blindly replace rather than attempting to determine
                # whether the entry has changed
                self._put_entry(self.connection.replace_network_acl_entry, rule)
            else:
                self._put_entry(self.connection.create_network_acl_entry, rule)
                changed = True

        for egress in (False, True):
            for number in existing[egress]:
                if number in wanted[egress]:
                    continue
                if number < self._RESERVED_RULE_NUMBER:
                    self.connection.delete_network_acl_entry(self.acl.id, number, egress)
                    changed = True

        return changed

    def create_rules(self):
        """Create every rule in ``self.rules`` on a freshly created ACL."""
        if self.rules is None:
            return
        for rule in self.rules:
            self._put_entry(self.connection.create_network_acl_entry, rule)

    def do_tags(self):
        """Apply the ``Name`` tag plus any user supplied tags to the ACL."""
        tags = {'Name': self.acl_name}
        for tag in self.tags or []:
            tags[tag['key']] = tag['value']
        self.get_acl().add_tags(tags)

    def present(self):
        """Ensure the ACL exists and return the module result dict."""
        if self.get_acl():
            changed = self.update_acl()
        else:
            changed = self.create_acl()
            self.create_rules()

        return dict(changed=changed,
                    id=self.acl.id,
                    name=self.acl_name,
                    entries=self.rules)

    def absent(self):
        """Ensure the ACL does not exist and return the module result dict.

        ``id`` is None in the result when there was no ACL to delete.
        """
        acl = self.get_acl()
        changed = False
        acl_id = None

        if acl:
            acl_id = acl.id
            changed = self.connection.delete_network_acl(acl_id)

        return dict(changed=changed,
                    id=acl_id,
                    name=self.acl_name)


def main():
    """Module entry point.

    Parses the module arguments, opens a VPC connection for the requested
    region and makes the named network ACL present or absent. Results and
    failures are reported through the AnsibleModule exit_json / fail_json
    calls.
    """
    argument_spec = ec2_argument_spec()
    argument_spec.update(
        dict(
            name=dict(required=True, type='str'),
            state=dict(default='present', choices=['present', 'absent']),
            vpc_id=dict(required=True, type='str'),
            rules=dict(type='list'),
            tags=dict(type='list'),
        )
    )

    module = AnsibleModule(argument_spec=argument_spec)

    ec2_url, aws_access_key, aws_secret_key, region = get_ec2_creds(module)
    if not region:
        module.fail_json(msg="region must be specified")

    try:
        # Hand the resolved credentials to boto explicitly so that keys given
        # as module parameters are honoured, not only boto's own lookup.
        connection = boto.vpc.connect_to_region(
            region,
            aws_access_key_id=aws_access_key,
            aws_secret_access_key=aws_secret_key,
            profile_name=module.params.get('profile'))
    except boto.exception.NoAuthHandlerFound as e:
        module.fail_json(msg=str(e))

    # connect_to_region returns None rather than raising for an unknown region.
    if connection is None:
        module.fail_json(msg="could not connect to region {0}".format(region))

    manager = ACLManager(connection,
                         module.params.get('vpc_id'),
                         module.params.get('name'),
                         module.params.get('rules'),
                         module.params.get('tags'))

    state = module.params.get('state')
    results = dict()

    try:
        if state == 'present':
            results = manager.present()
        elif state == 'absent':
            results = manager.absent()
    except DuplicateAclError as e:
        module.fail_json(msg=str(e))
    except boto.exception.BotoServerError as e:
        # Surface AWS API errors as a module failure instead of a traceback.
        module.fail_json(msg=str(e))

    module.exit_json(**results)

# Run the module only when executed as a script, so the file can be imported
# (for example by tests) without side effects.
if __name__ == '__main__':
    main()