# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# A derivative of a dnspython VersionedZone and related classes, using a BTreeDict and
# a separate per-version delegation index.  These additions let us
#
# 1) Do efficient CoW versioning (useful for future online updates).
# 2) Maintain sort order
# 3) Allow delegations to be found easily
# 4) Handle glue
# 5) Add Node flags ORIGIN, DELEGATION, and GLUE whenever relevant.  The ORIGIN
#    flag is set at the origin node, the DELEGATION flag is set at delegation
#    points, and the GLUE flag is set on nodes beneath delegation points.

import enum
from dataclasses import dataclass
from typing import Callable, MutableMapping, Tuple, cast

import dns.btree
import dns.immutable
import dns.name
import dns.node
import dns.rdataclass
import dns.rdataset
import dns.rdatatype
import dns.versioned
import dns.zone


class NodeFlags(enum.IntFlag):
    """Bit flags describing a node's role in the zone.

    ORIGIN marks the zone apex, DELEGATION marks a delegation point, and GLUE
    marks a node that lies beneath a delegation point.
    """

    ORIGIN = 1 << 0
    DELEGATION = 1 << 1
    GLUE = 1 << 2


class Node(dns.node.Node):
    """A zone node that also carries :class:`NodeFlags` and a version ``id``."""

    __slots__ = ["flags", "id"]

    def __init__(self, flags: NodeFlags | None = None):
        super().__init__()
        # ``None`` stands in for "no flags": pyright objects to assigning a
        # literal 0 to flags, so we cannot simply use it as the default.
        self.flags = NodeFlags(0) if flags is None else flags
        self.id = 0

    def is_delegation(self):
        """Return ``True`` if the DELEGATION flag is set."""
        return bool(self.flags & NodeFlags.DELEGATION)

    def is_glue(self):
        """Return ``True`` if the GLUE flag is set."""
        return bool(self.flags & NodeFlags.GLUE)

    def is_origin(self):
        """Return ``True`` if the ORIGIN flag is set."""
        return bool(self.flags & NodeFlags.ORIGIN)

    def is_origin_or_glue(self):
        """Return ``True`` if either the ORIGIN or the GLUE flag is set."""
        mask = NodeFlags.ORIGIN | NodeFlags.GLUE
        return bool(self.flags & mask)


@dns.immutable.immutable
class ImmutableNode(Node):
    """A read-only snapshot of a :class:`Node`.

    The source node's rdatasets are wrapped as ``ImmutableRdataset`` objects and
    stored in a tuple, and its ``id`` and ``flags`` are carried over.  Every
    mutating entry point (including lookups with ``create=True``) raises
    ``TypeError("immutable")``.
    """

    def __init__(self, node: Node):
        super().__init__()
        self.id = node.id
        frozen = [dns.rdataset.ImmutableRdataset(rds) for rds in node.rdatasets]
        self.rdatasets = tuple(frozen)  # type: ignore
        self.flags = node.flags

    def find_rdataset(
        self,
        rdclass: dns.rdataclass.RdataClass,
        rdtype: dns.rdatatype.RdataType,
        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
        create: bool = False,
    ) -> dns.rdataset.Rdataset:
        # Lookups are allowed; creation would mutate the node.
        if not create:
            return super().find_rdataset(rdclass, rdtype, covers, False)
        raise TypeError("immutable")

    def get_rdataset(
        self,
        rdclass: dns.rdataclass.RdataClass,
        rdtype: dns.rdatatype.RdataType,
        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
        create: bool = False,
    ) -> dns.rdataset.Rdataset | None:
        # Lookups are allowed; creation would mutate the node.
        if not create:
            return super().get_rdataset(rdclass, rdtype, covers, False)
        raise TypeError("immutable")

    def delete_rdataset(
        self,
        rdclass: dns.rdataclass.RdataClass,
        rdtype: dns.rdatatype.RdataType,
        covers: dns.rdatatype.RdataType = dns.rdatatype.NONE,
    ) -> None:
        raise TypeError("immutable")

    def replace_rdataset(self, replacement: dns.rdataset.Rdataset) -> None:
        raise TypeError("immutable")

    def is_immutable(self) -> bool:
        return True


class Delegations(dns.btree.BTreeSet[dns.name.Name]):
    """The ordered set of delegation points (zone cuts) in a zone version."""

    def get_delegation(self, name: dns.name.Name) -> Tuple[dns.name.Name | None, bool]:
        """Get the delegation applicable to *name*, if it exists.

        If there is a delegation, return a tuple consisting of the name of the
        delegation point and a boolean which is ``True`` if *name* is a proper
        subdomain of the delegation point, and ``False`` if it is equal to the
        delegation point.  If there is no delegation, return ``(None, False)``.
        """
        cursor = self.cursor()
        # Seeking with before=False leaves the cursor after *name*, so prev()
        # yields the greatest delegation point that sorts at or before *name*.
        cursor.seek(name, before=False)
        prev = cursor.prev()
        if prev is None:
            return None, False
        cut = prev.key()
        # NOTE(review): only the nearest preceding cut is examined.  If the set
        # ever holds a cut nested beneath another cut, names under the outer
        # cut but after the inner one would be missed.
        reln, _, _ = name.fullcompare(cut)
        is_subdomain = reln == dns.name.NameRelation.SUBDOMAIN
        if is_subdomain or reln == dns.name.NameRelation.EQUAL:
            return cut, is_subdomain
        return None, False

    def is_glue(self, name: dns.name.Name) -> bool:
        """Is *name* glue, i.e. is it strictly beneath a delegation point?"""
        # get_delegation() reports (None, False) when there is no delegation,
        # so the "proper subdomain" flag alone answers the question.
        _, is_subdomain = self.get_delegation(name)
        return is_subdomain


class WritableVersion(dns.zone.WritableVersion):
    """A writable zone version that keeps node flags and a delegation index current.

    Besides ``nodes``, each version owns ``delegations``, the ordered set of
    delegation points.  The mutators below keep the ORIGIN, DELEGATION, and GLUE
    node flags consistent with that set.
    """

    def __init__(self, zone: dns.zone.Zone, replacement: bool = False):
        # The base class is always told this is a replacement.  When we are not
        # replacing, the node map and the delegation index are derived from the
        # latest version through the BTree ``original=`` argument (presumably a
        # cheap copy-on-write clone -- TODO confirm against dns.btree).
        super().__init__(zone, True)
        if not replacement:
            assert isinstance(zone, dns.versioned.Zone)
            version = zone._versions[-1]
            self.nodes: dns.btree.BTreeDict[dns.name.Name, Node] = dns.btree.BTreeDict(
                original=version.nodes  # type: ignore
            )
            self.delegations = Delegations(original=version.delegations)  # type: ignore
        else:
            self.delegations = Delegations()

    def _is_origin(self, name: dns.name.Name) -> bool:
        # Assumes name has already been validated (and thus adjusted to the right
        # relativity too)
        if self.zone.relativize:
            return name == dns.name.empty
        else:
            return name == self.zone.origin

    def _maybe_cow_with_name(
        self, name: dns.name.Name
    ) -> Tuple[dns.node.Node, dns.name.Name]:
        """Return the writable node for *name* (copying it if needed), with flags set.

        A node that is copied here does not bring its old flags along, so every
        flag is re-derived from the version's state.  That includes DELEGATION.
        Without it, adding an rdataset of any type at an existing delegation
        point in a later version would silently drop the flag.
        """
        (node, name) = super()._maybe_cow_with_name(name)
        node = cast(Node, node)
        if self._is_origin(name):
            node.flags |= NodeFlags.ORIGIN
        elif self.delegations.is_glue(name):
            node.flags |= NodeFlags.GLUE
        if name in self.delegations:
            node.flags |= NodeFlags.DELEGATION
        return (node, name)

    def update_glue_flag(self, name: dns.name.Name, is_glue: bool) -> None:
        """Set (if *is_glue*) or clear the GLUE flag on every node beneath *name*.

        A node not yet changed in this version is copied first.  The copy keeps
        the original's rdatasets and all of its other flags, exactly as a node
        already changed in this version would.
        """
        cursor = self.nodes.cursor()  # type: ignore
        # Seeking with before=False leaves the cursor just after *name*, so the
        # first next() returns the first name that sorts after it.
        cursor.seek(name, False)
        updates = []
        while True:
            elt = cursor.next()
            if elt is None:
                break
            ename = elt.key()
            # The walk relies on the names beneath *name* being contiguous in the
            # ordering, so the first non-subdomain ends it.
            if not ename.is_subdomain(name):
                break
            node = cast(dns.node.Node, elt.value())
            if ename not in self.changed:
                new_node = self.zone.node_factory()
                new_node.id = self.id  # type: ignore
                new_node.rdatasets.extend(node.rdatasets)
                # Preserve the existing flags; only GLUE is recomputed below.
                new_node.flags = node.flags  # type: ignore
                self.changed.add(ename)
                node = new_node
            assert isinstance(node, Node)
            if is_glue:
                node.flags |= NodeFlags.GLUE
            else:
                node.flags &= ~NodeFlags.GLUE
            # We don't update node here as any insertion could disturb the
            # btree and invalidate our cursor.  We could use the cursor in a
            # with block and avoid this, but it would do a lot of parking and
            # unparking so the deferred update mode may still be better.
            updates.append((ename, node))
        for ename, node in updates:
            self.nodes[ename] = node

    def delete_node(self, name: dns.name.Name) -> None:
        """Delete the node at *name*, and its delegation if it is a delegation point."""
        name = self._validate_name(name)
        node = self.nodes.get(name)
        if node is not None:
            # Consult the delegation index, as delete_rdataset() does, rather
            # than the node's flag.  Otherwise a cut could outlive its node and
            # leave a stale entry behind for later lookups.
            if name in self.delegations:
                self.delegations.discard(name)
                self.update_glue_flag(name, False)
            del self.nodes[name]
            self.changed.add(name)

    def put_rdataset(
        self, name: dns.name.Name, rdataset: dns.rdataset.Rdataset
    ) -> None:
        """Store *rdataset* at *name*, recording a delegation for a non-apex,
        non-glue NS rdataset and marking the nodes beneath it as glue."""
        (node, name) = self._maybe_cow_with_name(name)
        if (
            rdataset.rdtype == dns.rdatatype.NS and not node.is_origin_or_glue()  # type: ignore
        ):
            node.flags |= NodeFlags.DELEGATION  # type: ignore
            if name not in self.delegations:
                self.delegations.add(name)
                self.update_glue_flag(name, True)
        node.replace_rdataset(rdataset)

    def delete_rdataset(
        self,
        name: dns.name.Name,
        rdtype: dns.rdatatype.RdataType,
        covers: dns.rdatatype.RdataType,
    ) -> None:
        """Delete an rdataset at *name*.

        Deleting the NS rdataset of a delegation point removes the delegation
        and un-glues the nodes beneath it.  The node is removed if it becomes
        empty.
        """
        (node, name) = self._maybe_cow_with_name(name)
        if rdtype == dns.rdatatype.NS and name in self.delegations:  # type: ignore
            node.flags &= ~NodeFlags.DELEGATION  # type: ignore
            self.delegations.discard(name)  # type: ignore
            self.update_glue_flag(name, False)
        node.delete_rdataset(self.zone.rdclass, rdtype, covers)
        if len(node) == 0:
            del self.nodes[name]


@dataclass(frozen=True)
class Bounds:
    """Where a queried name falls within a zone version (see ``bounds()``)."""

    # The name whose bounds were computed.
    name: dns.name.Name
    # The greatest zone name sorting at or before ``name``.
    left: dns.name.Name
    # The next zone name after ``left`` (or ``None`` if there is none).
    right: dns.name.Name | None
    # The deepest name shared with the neighbouring zone names.
    closest_encloser: dns.name.Name
    # True when ``name`` equals ``left``.
    is_equal: bool
    # True when ``name`` is at or beneath a delegation point.
    is_delegation: bool

    def __str__(self):
        relation = "=" if self.is_equal else "<"
        suffix = " zonecut" if self.is_delegation else ""
        head = f"{self.left} {relation} {self.name} < {self.right}{suffix}"
        return f"{head}; {self.closest_encloser}"


@dns.immutable.immutable
class ImmutableVersion(dns.zone.Version):
    """A read-only zone version built from a :class:`WritableVersion`.

    The writable version's node map and delegation index are frozen in place.
    Every node changed in that version is first replaced by an
    :class:`ImmutableNode`.
    """

    def __init__(self, version: dns.zone.Version):
        if not isinstance(version, WritableVersion):
            raise ValueError(
                "a dns.btreezone.ImmutableVersion requires a "
                "dns.btreezone.WritableVersion"
            )
        super().__init__(version.zone, True)
        self.id = version.id
        self.origin = version.origin
        for name in version.changed:
            node = version.nodes.get(name)
            # A changed name may have been deleted in the version.  Test for
            # None explicitly: a node holding no rdatasets has len() == 0 and is
            # falsy, yet it must still be frozen.
            if node is not None:
                version.nodes[name] = ImmutableNode(node)
        # the cast below is for mypy
        self.nodes = cast(MutableMapping[dns.name.Name, dns.node.Node], version.nodes)
        self.nodes.make_immutable()  # type: ignore
        self.delegations = version.delegations
        self.delegations.make_immutable()

    def bounds(self, name: dns.name.Name | str) -> Bounds:
        """Return the 'bounds' of *name* in its zone.

        The bounds information is useful when making an authoritative response, as
        it can be used to determine whether the query name is at or beneath a delegation
        point.  The other data in the ``Bounds`` object is useful for making on-the-fly
        DNSSEC signatures.

        The left bound of *name* is *name* itself if it is in the zone, or the greatest
        predecessor which is in the zone.  If *name* is at or beneath a delegation
        point, the left bound is the delegation point itself.

        The right bound of *name* is the least successor of the left bound that is
        not glue, or ``None`` if there is no such name.

        The closest encloser of *name* is *name* itself, if *name* is in the zone;
        otherwise it is the name with the largest number of labels in common with
        *name* that is in the zone, either explicitly or by the implied existence
        of empty non-terminals.

        The bounds *is_equal* field is ``True`` if and only if *name* is equal to
        its left bound.

        The bounds *is_delegation* field is ``True`` if and only if *name* is at or
        beneath a delegation point (in which case the left bound is that point).
        """
        assert self.origin is not None
        # validate the origin because we may need to relativize
        origin = self.zone._validate_name(self.origin)
        name = self.zone._validate_name(name)
        cut, _ = self.delegations.get_delegation(name)
        if cut is not None:
            target = cut
            is_delegation = True
        else:
            target = name
            is_delegation = False
        c = cast(dns.btree.BTreeDict, self.nodes).cursor()
        c.seek(target, False)
        left = c.prev()
        assert left is not None
        c.next()  # skip over left
        # Glue beneath a delegation point is not authoritative data, so skip it
        # when looking for the right bound.
        while True:
            right = c.next()
            if right is None or not right.value().is_glue():
                break
        left_comparison = left.key().fullcompare(name)
        if right is not None:
            right_key = right.key()
            right_comparison = right_key.fullcompare(name)
        else:
            right_comparison = (
                dns.name.NameRelation.COMMONANCESTOR,
                -1,
                len(origin),
            )
            right_key = None
        common = max(left_comparison[2], right_comparison[2])
        # Slice from an explicit start index.  With relative names (relativized
        # zones) the shared label count can be 0, and name[-0:] would wrongly
        # yield the whole name instead of the empty (apex) name.
        closest_encloser = dns.name.Name(name[len(name) - common :])
        return Bounds(
            name,
            left.key(),
            right_key,
            closest_encloser,
            left_comparison[0] == dns.name.NameRelation.EQUAL,
            is_delegation,
        )


class Zone(dns.versioned.Zone):
    """A versioned zone stored in B-tree maps, with delegation tracking.

    Nodes are :class:`Node` objects, the node map is a ``dns.btree.BTreeDict``,
    and versions are created by this module's :class:`WritableVersion` and
    :class:`ImmutableVersion`.
    """

    writable_version_factory: (
        Callable[[dns.zone.Zone, bool], dns.zone.Version] | None
    ) = WritableVersion
    immutable_version_factory: (
        Callable[[dns.zone.Version], dns.zone.Version] | None
    ) = ImmutableVersion
    node_factory: Callable[[], dns.node.Node] = Node
    # The cast only satisfies type checkers; at runtime this is BTreeDict itself.
    map_factory: Callable[[], MutableMapping[dns.name.Name, dns.node.Node]] = (
        cast(
            Callable[[], MutableMapping[dns.name.Name, dns.node.Node]],
            dns.btree.BTreeDict[dns.name.Name, Node],
        )
    )
