#!/usr/bin/env python
# This file is part of Ansible
#
# Ansible is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Ansible is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with Ansible.  If not, see <http://www.gnu.org/licenses/>.

DOCUMENTATION = """
---
module: ec2_subnet
short_description: Create or delete AWS Subnets
description:
  - Can create or delete AWS Subnets
version_added: "1.8"
author: Edward Zarecor
options:
  state:
    description:
      - create, update or delete the subnet
    required: true
    choices: ['present', 'absent']
  name:
    description:
      - Unique name for subnet
    required: true
  cidr_block:
    description:
      - The cidr block of the subnet
    aliases: ['cidr']
  availability_zone:
    description:
      - The availability zone of the subnet
    aliases: ['az']
  vpc_id:
    description:
      - The VPC that this subnet belongs to
    required: true
    default: null
extends_documentation_fragment: aws
"""

EXAMPLES = '''
'''

from ansible.module_utils.basic import *
from ansible.module_utils.ec2 import *
import sys
try:
    import boto.vpc
except ImportError:
    print "failed=True msg='boto required for this module'"
    sys.exit(1)

from boto.exception import NoAuthHandlerFound


class NonUniqueSubnetSpecification(Exception):
    """Raised when a subnet lookup matches more than one subnet."""


class SubnetManager(object):
    """
    Creates, updates or deletes one VPC subnet through a VPC connection.

    The subnet is looked up by the combination of VPC id, CIDR block and
    availability zone.  The connection object must provide the calls used
    below (get_all_subnets, create_subnet, delete_subnet,
    get_all_route_tables, associate_route_table,
    replace_route_table_association_with_assoc, associate_network_acl).
    """

    def __init__(self, connection, vpc_id, cidr_block, az, name, route_table_id, network_acl_id, tags=None):
        """
        :param connection: VPC connection used for every API call
        :param vpc_id: id of the VPC the subnet lives in
        :param cidr_block: CIDR block of the subnet
        :param az: availability zone of the subnet
        :param name: value written to the subnet's ``Name`` tag
        :param route_table_id: route table the subnet is associated with
        :param network_acl_id: optional network ACL to associate; skipped
            when falsy
        :param tags: optional list of ``{'key': ..., 'value': ...}`` dicts
        """
        self.connection = connection
        self.vpc_id = vpc_id
        self.cidr_block = cidr_block
        self.az = az
        self.name = name
        self.route_table_id = route_table_id
        self.network_acl_id = network_acl_id
        # Always keep a private list so that instances never share (or
        # mutate) a default argument or the caller's list.
        self.tags = list(tags) if tags else []
        self.subnet = None

    def get_subnet(self):
        """
        Look up the subnet matching vpc_id, cidr_block and az, caching it
        on the instance.

        :return: the matching subnet object, or None when there is none
        :raises NonUniqueSubnetSpecification: when more than one subnet
            matches the filter
        """
        if not self.subnet:
            subnets = self.connection.get_all_subnets(
                filters={"vpc_id": self.vpc_id,
                         "cidr_block": self.cidr_block,
                         "availability_zone": self.az})

            if len(subnets) > 1:
                message = "Subnet specifier of vpc_id {vpc_id}, cidr_block {cidr_block} " \
                          "and az {az}, return more than one result".format(
                              vpc_id=self.vpc_id, cidr_block=self.cidr_block, az=self.az)
                raise NonUniqueSubnetSpecification(message)
            elif len(subnets) == 1:
                self.subnet = subnets[0]

        return self.subnet

    def present(self):
        """
        Ensure the subnet exists, creating it or updating the existing one.

        :return: dict with changed, subnet_id, subnet_name and vpc_id
        """
        if self.get_subnet():
            changed = self.update_subnet()
        else:
            changed = self.create_subnet()

        return dict(changed=changed,
                    subnet_id=self.subnet.id,
                    subnet_name=self.name,
                    vpc_id=self.vpc_id)

    def create_subnet(self):
        """
        Create the subnet, tag it and attach the route table (and the
        network ACL when one was given).

        :return: True, a creation is always a change
        """
        self.subnet = self.connection.create_subnet(
            self.vpc_id, self.cidr_block, availability_zone=self.az)
        self.do_tags()
        self.connection.associate_route_table(self.route_table_id, self.subnet.id)
        if self.network_acl_id:
            self.connection.associate_network_acl(self.network_acl_id, self.subnet.id)
        return True

    def update_subnet(self):
        """
        Re-apply tags and make sure the existing subnet is associated with
        the requested route table (and network ACL when one was given).

        :return: True when the route table association was created or
            replaced, False otherwise
        """
        changed = False
        self.do_tags()

        results = self.connection.get_all_route_tables(
            filters={'association.subnet_id': self.subnet.id, 'vpc_id': self.vpc_id})

        if len(results) == 1:
            assoc = self.get_association_from_route_table(results[0], self.subnet)
            if assoc is None:
                # The route table came back but carries no association for
                # this subnet; treat it like a missing association.
                self.connection.associate_route_table(self.route_table_id, self.subnet.id)
                changed = True
            elif assoc.route_table_id != self.route_table_id:
                self.connection.replace_route_table_association_with_assoc(assoc.id, self.route_table_id)
                changed = True
        elif len(results) == 0:
            # unlikely unless manual monkeying around
            self.connection.associate_route_table(self.route_table_id, self.subnet.id)
            changed = True

        # NOTE: the ACL association is re-applied on every run and is not
        # reflected in ``changed``; comparing against the currently
        # associated ACL first is still to be done.
        if self.network_acl_id:
            self.connection.associate_network_acl(self.network_acl_id, self.subnet.id)

        return changed

    def absent(self):
        """
        Delete the subnet if it exists.

        :return: dict with changed; changed is False when no matching
            subnet exists
        :raises NonUniqueSubnetSpecification: when more than one subnet
            matches the filter
        """
        subnet = self.get_subnet()
        if subnet is None:
            # Nothing to delete: already in the requested state.
            return dict(changed=False)

        changed = bool(self.connection.delete_subnet(subnet.id))
        if changed:
            # Drop the cached object so later lookups hit the API again.
            self.subnet = None
        return dict(changed=changed)

    def get_association_from_route_table(self, route_table, subnet):
        """
        Find the association of route_table that points at subnet.

        :return: the first association whose subnet_id equals subnet.id,
            or None when there is none
        """
        for assoc in route_table.associations:
            if assoc.subnet_id == subnet.id:
                return assoc
        return None

    def do_tags(self):
        """
        Utility that creates all tags including the Name tag which is treated
        as a first class params as a convenience.  Currently updates
        existing tags, as the API overwrites them, but does not remove
        orphans.
        :return: None
        """
        tags = {'Name': self.name}
        for tag in self.tags or []:
            tags[tag['key']] = tag['value']

        self.subnet.add_tags(tags)


def main():
    """
    Module entry point.

    Builds the argument spec, connects to the VPC API in the requested
    region and applies the requested state through SubnetManager.  Every
    failure is reported through module.fail_json so that Ansible receives
    a structured error instead of a traceback.
    """
    argument_spec = ec2_argument_spec()
    argument_spec.update(
        dict(
            name=dict(required=True, type='str'),
            state=dict(default='present', choices=['present', 'absent']),
            vpc_id=dict(required=True, type='str'),
            cidr_block=dict(required=True, type='str', aliases=['cidr']),
            az=dict(required=True, type='str'),
            route_table_id=dict(required=True, type='str'),
            network_acl_id=dict(type='str'),
            tags=dict(type='list'),
        )
    )

    module = AnsibleModule(argument_spec=argument_spec)

    # NOTE(review): only the region is used from the returned credentials;
    # the explicit url/keys are not passed to boto, which authenticates via
    # the profile or its own credential lookup -- confirm this is intended.
    ec2_url, aws_access_key, aws_secret_key, region = get_ec2_creds(module)
    if not region:
        module.fail_json(msg="region must be specified")

    profile = module.params.get('profile')
    try:
        connection = boto.vpc.connect_to_region(region, profile_name=profile)
    except boto.exception.NoAuthHandlerFound as e:
        module.fail_json(msg=str(e))

    # connect_to_region returns None instead of raising when the region
    # name is not known to boto.
    if connection is None:
        module.fail_json(msg="unable to connect to region {0}".format(region))

    manager = SubnetManager(
        connection,
        module.params.get('vpc_id'),
        module.params.get('cidr_block'),
        module.params.get('az'),
        module.params.get('name'),
        module.params.get('route_table_id'),
        module.params.get('network_acl_id'),
        module.params.get('tags'),
    )

    state = module.params.get('state')

    try:
        if state == 'present':
            results = manager.present()
        elif state == 'absent':
            results = manager.absent()
        else:
            # Unreachable while the argument spec restricts the choices;
            # kept as a guard in case the spec and this dispatch drift apart.
            module.fail_json(msg="Unexpected value for state {0}".format(state))
    except (NonUniqueSubnetSpecification, boto.exception.EC2ResponseError) as e:
        module.fail_json(msg=str(e))

    module.exit_json(**results)

# Run the module. There is no __main__ guard, so this executes whenever the file is loaded.
main()