import stix
import stix.ttp.malware_instance
from stix.ttp.malware_instance import MalwareInstance
import stix.bindings.extensions.malware.maec_4_1 as ext_binding
from lxml import etree
from itertools import izip

# TODO: enable use of python-maec here

class MAECInstance(MalwareInstance):
    """Malware instance extension that carries a MAEC 4.1 XML document.

    The MAEC content is held as an lxml element tree and is exposed through
    the ``maec`` property.
    """
    _binding = ext_binding
    _binding_class = _binding.MAEC4_1InstanceType
    _namespace = 'http://stix.mitre.org/extensions/Malware#MAEC4.1-1'
    _XSI_TYPE = "maecInstance:MAEC4.1InstanceType"

    def __init__(self, maec=None):
        """Create the instance and route ``maec`` through the property setter.

        Args:
            maec: ``None``, an ``lxml.etree._Element`` or an
                ``lxml.etree._ElementTree``.

        Raises:
            ValueError: if ``maec`` is rejected by the ``maec`` setter.
        """
        super(MAECInstance, self).__init__()
        self.maec = maec

    @property
    def maec(self):
        """The MAEC document as an ``lxml.etree._ElementTree``, or ``None``.

        Assigning an ``_Element`` wraps it in a new ``_ElementTree``;
        assigning an ``_ElementTree`` stores it as-is.  If the root element
        lives in the MAEC package namespace instead of the STIX extension
        namespace, its tag is rewritten *in place* on the caller's element.
        Assigning ``None`` (or any other falsy non-lxml value) clears the
        property.

        On a successful assignment the root's namespace map and
        ``xsi:schemaLocation`` pairs are recorded in
        ``__input_namespaces__`` and ``__input_schemalocations__``.

        Raises:
            ValueError: on assignment, if the value is of an unsupported
                type, the tree has no root element, or the root tag is
                neither the expected tag nor in the MAEC package namespace.
        """
        return self._maec

    @maec.setter
    def maec(self, value):
        # The type checks must come before the truthiness check: an lxml
        # element without children evaluates as False and would otherwise be
        # silently discarded as if it were None.
        if isinstance(value, etree._ElementTree):
            tree = value
        elif isinstance(value, etree._Element):
            tree = etree.ElementTree(value)
        elif not value:
            self._maec = None
            return
        else:
            raise ValueError('maec must be instance of lxml.etree._Element '
                             'or lxml.etree._ElementTree')

        root = tree.getroot()
        if root is None:
            # An _ElementTree can exist without a root; fail with a clear
            # message instead of an AttributeError on ``root.tag`` below.
            raise ValueError("Cannot set maec property. The element tree "
                             "has no root element")

        expected_node_tag = "{%s}MAEC" % (self._namespace)
        if root.tag != expected_node_tag:
            ns_maec = "http://maec.mitre.org/XMLSchema/maec-package-2"
            # Tags use Clark notation ("{namespace}local"); pull out the
            # namespace part when there is one.
            if root.tag.startswith("{"):
                node_ns = root.tag[1:].partition("}")[0]
            else:
                node_ns = None

            if node_ns != ns_maec:
                raise ValueError("Cannot set maec property. Expected tag %s found %s"
                                 % (expected_node_tag, root.tag))

            # attempt to cast: rename the root to the expected tag
            root.tag = expected_node_tag

        # Map namespace URI -> alias (the inverse of lxml's nsmap).
        self.__input_namespaces__ = dict(
            (ns, alias) for alias, ns in root.nsmap.items()
        )

        # xsi:schemaLocation is a whitespace-separated list of alternating
        # namespace / location tokens; an unpaired trailing token is dropped.
        self.__input_schemalocations__ = {}
        schemaloc_str = root.attrib.get(
            '{http://www.w3.org/2001/XMLSchema-instance}schemaLocation'
        )
        if schemaloc_str:
            tokens = schemaloc_str.split()
            self.__input_schemalocations__ = dict(
                zip(tokens[0::2], tokens[1::2])
            )

        self._maec = tree

    @classmethod
    def from_obj(cls, obj, return_obj=None):
        """Build a ``MAECInstance`` from a binding object.

        Args:
            obj: binding object providing ``get_MAEC()``; a falsy value
                yields ``None``.
            return_obj: optional instance to populate; a new one is created
                when omitted.

        Returns:
            The populated instance, or ``None`` when ``obj`` is falsy.
        """
        if not obj:
            return None
        if not return_obj:
            return_obj = cls()

        super(MAECInstance, cls).from_obj(obj, return_obj)
        return_obj.maec = obj.get_MAEC()
        return return_obj

    def to_obj(self, return_obj=None):
        """Populate a binding object from this instance.

        Args:
            return_obj: optional binding object to fill; an instance of
                ``_binding_class`` is created when omitted.

        Returns:
            The binding object with its MAEC content set to ``self.maec``.
        """
        if not return_obj:
            return_obj = self._binding_class()

        super(MAECInstance, self).to_obj(return_obj)
        return_obj.set_MAEC(self.maec)
        return return_obj

    @classmethod
    def from_dict(cls, d, return_obj=None):
        """Build a ``MAECInstance`` from a dictionary.

        Args:
            d: source dictionary; a falsy value yields ``None``.  When it
                has a ``'maec'`` key, the value is handed to
                ``lxml.etree.parse``.
            return_obj: optional instance to populate; a new one is created
                when omitted.

        Returns:
            The populated instance, or ``None`` when ``d`` is falsy.

        NOTE(review): ``etree.parse`` expects a filename, URL or file-like
        object, so a raw XML string under ``'maec'`` would be treated as a
        path -- confirm what callers actually pass.
        """
        if not d:
            return None
        if not return_obj:
            return_obj = cls()

        super(MAECInstance, cls).from_dict(d, return_obj)
        if 'maec' in d:
            parser = etree.ETCompatXMLParser(huge_tree=True)
            return_obj.maec = etree.parse(d['maec'], parser=parser)

        return return_obj
    

# Register MAECInstance with stix.ttp.malware_instance at import time.
# NOTE(review): presumably this lets the class be looked up by its xsi:type
# when malware instances are parsed -- confirm against add_extension.
stix.ttp.malware_instance.add_extension(MAECInstance)
