###############################################################################
#
#   Agora Portfolio & Risk Management System
#
#   Copyright 2015 Carlo Sbraccia
#
#   Licensed under the Apache License, Version 2.0 (the "License");
#   you may not use this file except in compliance with the License.
#   You may obtain a copy of the License at
#
#       http://www.apache.org/licenses/LICENSE-2.0
#
#   Unless required by applicable law or agreed to in writing, software
#   distributed under the License is distributed on an "AS IS" BASIS,
#   WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#   See the License for the specific language governing permissions and
#   limitations under the License.
#
###############################################################################

from onyx.core import (RDate, Structure, ValueType, GetVal,
                       IntField, FloatField, DateField, StringField,
                       SelectField, ReferenceField)

from agora.corelibs.date_functions import DateOffset
from agora.corelibs.tradable_api import (FwdTradableObj,
                                         AddByInference, HashStoredAttrs)

from agora.tradables.ufo_commod_forward import CommodForward
from agora.tradables.ufo_forward_cash import ForwardCash
from agora.tradables.ufo_commod_nrby import AVERAGING_TYPES


###############################################################################
class CommodSwap(FwdTradableObj):
    """
    Tradable class that represents a commodity swap contract.

    The swap is decomposed (see Leaves) into one CommodForward per monthly
    averaging period plus one ForwardCash leg per period carrying the fixed
    payment. The USD valuation value types aggregate over those leaves; the
    non-USD ones convert the USD figure into the Denominated currency.
    """
    # --- stored attributes
    Asset = ReferenceField(obj_type="CommodAsset")
    AvgStartDate = DateField()
    AvgEndDate = DateField()
    AvgType = SelectField(options=AVERAGING_TYPES)
    RollType = IntField(default=0)
    Quantity = IntField()
    FixedPrice = FloatField()
    Denominated = ReferenceField(obj_type="Currency", default="USD")
    # --- date rule applied to each period end to get the cash payment date
    PaymentRule = StringField(default="+5b")

    # -------------------------------------------------------------------------
    @ValueType()
    def Leaves(self, graph):
        """
        Return a Structure mapping leaf security names to quantities.

        For every period visited by the loop two leaves are added:
          - a CommodForward averaging over the period, with quantity Quantity;
          - a ForwardCash in the Denominated currency, paid on the period end
            shifted by PaymentRule on the asset's holiday calendar, with
            quantity -Quantity*FixedPrice.
        """
        asset = graph(self, "Asset")
        qty = graph(self, "Quantity")
        price = graph(self, "FixedPrice")
        sd = graph(self, "AvgStartDate")
        ed = graph(self, "AvgEndDate")
        one_month = RDate("+1m")
        den = graph(self, "Denominated")
        # --- the stored attribute declared on this class is PaymentRule
        payment_rule = graph(self, "PaymentRule")
        cal = graph(asset, "HolidayCalendar")
        avg_type = graph(self, "AvgType")
        roll_type = graph(self, "RollType")

        # --- settlement on a monthly basis
        # NOTE(review): the period start advances by one month while the
        # period end is capped by the "+e" offset (presumably end of month),
        # so this looks like it expects AvgStartDate to be the first day of a
        # month; a swap with AvgStartDate == AvgEndDate yields no leaves.
        # Confirm both are intended.
        leaves = Structure()
        while sd < ed:
            mth_ed = min(ed, DateOffset(sd, "+e"))
            fwd_info = {
                "Asset": asset,
                "AvgStartDate": sd,
                "AvgEndDate": mth_ed,
                "AvgType": avg_type,
                "RollType": roll_type,
            }
            sec = AddByInference(CommodForward(**fwd_info), True)
            leaves[sec.Name] = qty

            # --- cash leg
            cash_info = {
                "Currency": den,
                "PaymentDate": DateOffset(mth_ed, payment_rule, cal),
            }
            sec = AddByInference(ForwardCash(**cash_info), True)
            leaves[sec.Name] = -qty*price

            sd += one_month

        return leaves

    # -------------------------------------------------------------------------
    @ValueType()
    def UndiscountedValue(self, graph):
        """
        Return UndiscountedValueUSD divided by the "<Denominated>/USD" Spot.
        """
        fx = graph("{0:3s}/USD".format(graph(self, "Denominated")), "Spot")
        return graph(self, "UndiscountedValueUSD") / fx

    # -------------------------------------------------------------------------
    @ValueType()
    def MktVal(self, graph):
        """
        Return MktValUSD divided by the "<Denominated>/USD" Spot.
        """
        fx = graph("{0:3s}/USD".format(graph(self, "Denominated")), "Spot")
        return graph(self, "MktValUSD") / fx

    # -------------------------------------------------------------------------
    @ValueType()
    def UndiscountedValueUSD(self, graph):
        """
        Return the quantity-weighted sum of the leaves' UndiscountedValueUSD.
        """
        val = 0.0
        for sec, qty in graph(self, "Leaves").items():
            val += qty*graph(sec, "UndiscountedValueUSD")
        return val

    # -------------------------------------------------------------------------
    @ValueType()
    def MktValUSD(self, graph):
        """
        Return the quantity-weighted sum of the leaves' MktValUSD.
        """
        val = 0.0
        for sec, qty in graph(self, "Leaves").items():
            val += qty*graph(sec, "MktValUSD")
        return val

    # -------------------------------------------------------------------------
    @ValueType()
    def ExpirationDate(self, graph):
        """
        Return the latest ExpirationDate among the leaves.
        """
        securities = graph(self, "Leaves").keys()
        return max([GetVal(sec, "ExpirationDate") for sec in securities])

    # -------------------------------------------------------------------------
    @ValueType()
    def NextTransactionDate(self, graph):
        """
        Return the latest NextTransactionDate among the leaves.
        """
        # NOTE(review): this takes the latest leaf date, i.e. the swap is
        # treated as transacting once after all of its leaves; if the first
        # upcoming leaf event is wanted instead this should be min - confirm.
        securities = graph(self, "Leaves").keys()
        return max([GetVal(sec, "NextTransactionDate") for sec in securities])

    # -------------------------------------------------------------------------
    @ValueType()
    def NextTransactionEvent(self, graph):
        """
        Return the name of the next transaction event (always "Settlement").
        """
        return "Settlement"

    # -------------------------------------------------------------------------
    @ValueType()
    def NextTransactionSecurity(self, graph):
        """
        Return None: no security is associated with the next transaction.
        """
        return None

    # -------------------------------------------------------------------------
    @property
    def ImpliedName(self):
        """
        Return the name template built from the asset's Market and Symbol and
        an 8-character hash of the stored attributes.
        """
        # --- a property getter only receives self, hence no graph argument
        mkt = GetVal(self.Asset, "Market")
        sym = GetVal(self.Asset, "Symbol")
        # NOTE(review): the doubled braces leave a literal "{0:2d}" at the end
        # of the returned string; presumably the naming machinery fills it in
        # with a counter afterwards - confirm against AddByInference.
        return ("CmdSWP {0:s} {1:s} {2:8s} "
                "{{0:2d}}").format(mkt, sym, HashStoredAttrs(self, 8))
