# Copyright (C) Dnspython Contributors, see LICENSE for text of ISC license

# Copyright (C) 2003-2017 Nominum, Inc.
#
# Permission to use, copy, modify, and distribute this software and its
# documentation for any purpose with or without fee is hereby granted,
# provided that the above copyright notice and this permission notice
# appear in all copies.
#
# THE SOFTWARE IS PROVIDED "AS IS" AND NOMINUM DISCLAIMS ALL WARRANTIES
# WITH REGARD TO THIS SOFTWARE INCLUDING ALL IMPLIED WARRANTIES OF
# MERCHANTABILITY AND FITNESS. IN NO EVENT SHALL NOMINUM BE LIABLE FOR
# ANY SPECIAL, DIRECT, INDIRECT, OR CONSEQUENTIAL DAMAGES OR ANY DAMAGES
# WHATSOEVER RESULTING FROM LOSS OF USE, DATA OR PROFITS, WHETHER IN AN
# ACTION OF CONTRACT, NEGLIGENCE OR OTHER TORTIOUS ACTION, ARISING OUT
# OF OR IN CONNECTION WITH THE USE OR PERFORMANCE OF THIS SOFTWARE.

"""Talk to a DNS server."""

import base64
import contextlib
import enum
import errno
import os
import random
import selectors
import socket
import struct
import time
import urllib.parse
from typing import Any, Callable, Dict, Optional, Tuple, cast

import dns._features
import dns._tls_util
import dns.exception
import dns.inet
import dns.message
import dns.name
import dns.quic
import dns.rdata
import dns.rdataclass
import dns.rdatatype
import dns.transaction
import dns.tsig
import dns.xfr

try:
    import ssl
except ImportError:
    import dns._no_ssl as ssl  # type: ignore


def _remaining(expiration):
    """Return the number of seconds left before *expiration*.

    *expiration* is an absolute time on the ``time.time()`` clock, or ``None``
    meaning "no deadline", in which case ``None`` is returned.

    Raises ``dns.exception.Timeout`` if the deadline has been reached or passed.
    """
    if expiration is not None:
        seconds_left = expiration - time.time()
        if seconds_left <= 0.0:
            raise dns.exception.Timeout
        return seconds_left
    return None


def _expiration_for_this_attempt(timeout, expiration):
    """Return the deadline for a single attempt.

    The attempt may take at most *timeout* seconds from now, but never runs
    past the overall absolute *expiration*.  If *expiration* is ``None`` there
    is no overall deadline and ``None`` is returned (the attempt is unbounded).
    """
    return None if expiration is None else min(time.time() + timeout, expiration)


# DNS-over-HTTPS via httpx is optional; the feature flag decides which
# _HTTPTransport definition the rest of the module sees.
_have_httpx = dns._features.have("doh")
if _have_httpx:
    import httpcore._backends.sync
    import httpx

    _CoreNetworkBackend = httpcore.NetworkBackend
    _CoreSyncStream = httpcore._backends.sync.SyncStream

    class _NetworkBackend(_CoreNetworkBackend):
        """httpcore network backend that makes its own TCP connections.

        Hostnames are resolved with the supplied resolver (or replaced by a
        fixed bootstrap address), and sockets are created with this module's
        ``make_socket()`` so a local port can be bound.
        """

        def __init__(self, resolver, local_port, bootstrap_address, family):
            super().__init__()
            self._local_port = local_port
            self._resolver = resolver
            self._bootstrap_address = bootstrap_address
            self._family = family

        def connect_tcp(
            self, host, port, timeout=None, local_address=None, socket_options=None
        ):  # pylint: disable=signature-differs
            """Connect to *host*:*port* and return a sync stream for the socket.

            *host* is used directly if it is an IP address; otherwise the
            bootstrap address is used if one was configured; otherwise *host*
            is resolved with the configured resolver.  Each candidate address
            is tried in turn.  *timeout* bounds the whole operation and, when
            it is set, each connect attempt is limited to 2 seconds.
            *socket_options* is accepted for interface compatibility and is
            not used.

            Raises ``httpcore.ConnectError`` if no address could be connected.
            """
            addresses = []
            _, expiration = _compute_times(timeout)
            if dns.inet.is_address(host):
                addresses.append(host)
            elif self._bootstrap_address is not None:
                addresses.append(self._bootstrap_address)
            else:
                timeout = _remaining(expiration)
                family = self._family
                if local_address:
                    family = dns.inet.af_for_address(local_address)
                answers = self._resolver.resolve_name(
                    host, family=family, lifetime=timeout
                )
                addresses = answers.addresses()
            for address in addresses:
                af = dns.inet.af_for_address(address)
                if local_address is not None or self._local_port != 0:
                    # When only a local port was requested, bind to the
                    # wildcard address of *this candidate's* family.  A fixed
                    # IPv4 wildcard cannot be bound on an IPv6 socket, and the
                    # caller's local_address must not be overwritten for later
                    # candidates.
                    bind_address = local_address
                    if bind_address is None:
                        bind_address = dns.inet.any_for_af(af)
                    source = dns.inet.low_level_address_tuple(
                        (bind_address, self._local_port), af
                    )
                else:
                    source = None
                sock = None
                try:
                    sock = make_socket(af, socket.SOCK_STREAM, source)
                    attempt_expiration = _expiration_for_this_attempt(2.0, expiration)
                    _connect(
                        sock,
                        dns.inet.low_level_address_tuple((address, port), af),
                        attempt_expiration,
                    )
                    return _CoreSyncStream(sock)
                except Exception:
                    # Best effort: fall through to the next candidate address,
                    # but do not leak the socket of the failed attempt.
                    if sock is not None:
                        sock.close()
            raise httpcore.ConnectError

        def connect_unix_socket(
            self, path, timeout=None, socket_options=None
        ):  # pylint: disable=signature-differs
            """Unix domain sockets are not supported by this backend."""
            raise NotImplementedError

    class _HTTPTransport(httpx.HTTPTransport):  # pyright: ignore
        """httpx transport whose connection pool uses ``_NetworkBackend``."""

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            # A resolver is only needed when hostnames must be looked up,
            # i.e. when no bootstrap address short-circuits resolution.
            if resolver is None and bootstrap_address is None:
                # pylint: disable=import-outside-toplevel,redefined-outer-name
                import dns.resolver

                resolver = dns.resolver.Resolver()
            super().__init__(*args, **kwargs)
            # Replace the pool's network backend with ours.
            self._pool._network_backend = _NetworkBackend(
                resolver, local_port, bootstrap_address, family
            )

else:

    class _HTTPTransport:  # type: ignore
        """Placeholder used when httpx is not available.

        The constructor accepts and ignores the same arguments as the real
        transport.
        """

        def __init__(
            self,
            *args,
            local_port=0,
            bootstrap_address=None,
            resolver=None,
            family=socket.AF_UNSPEC,
            **kwargs,
        ):
            pass

        def connect_tcp(self, host, port, timeout, local_address):
            raise NotImplementedError


# Public flag: True if DNS-over-HTTPS via httpx is available.
have_doh = _have_httpx


def default_socket_factory(
    af: socket.AddressFamily | int,
    kind: socket.SocketKind,
    proto: int,
) -> socket.socket:
    return socket.socket(af, kind, proto)


# Function used to create a socket.  Can be overridden if needed in special
# situations.  A replacement must accept (address family, socket kind,
# protocol number) and return a socket object; make_socket() calls it with a
# protocol number of 0.
socket_factory: Callable[
    [socket.AddressFamily | int, socket.SocketKind, int], socket.socket
] = default_socket_factory


# Raised by _matches_destination() when a reply's source address does not
# match the destination and unexpected sources are not being ignored.
# NOTE(review): the docstring may double as the default exception message in
# dns.exception.DNSException -- confirm before rewording it.
class UnexpectedSource(dns.exception.DNSException):
    """A DNS query response came from an unexpected address or port."""


# Raised by the query functions in this module when the received message
# fails the query's is_response() check.
# NOTE(review): the docstring may double as the default exception message in
# the dns.exception base classes -- confirm before rewording it.
class BadResponse(dns.exception.FormError):
    """A DNS query response does not respond to the question asked."""


# Raised by https() when httpx support is unavailable, and by _http3() when
# QUIC support is unavailable.
# NOTE(review): the docstring may double as the default exception message in
# dns.exception.DNSException -- confirm before rewording it.
class NoDOH(dns.exception.DNSException):
    """DNS over HTTPS (DOH) was requested but the httpx module is not
    available."""


# Not raised in the portion of the module shown here; presumably used by the
# DNS-over-QUIC query function -- verify against the rest of the file.
# NOTE(review): the docstring may double as the default exception message in
# dns.exception.DNSException -- confirm before rewording it.
class NoDOQ(dns.exception.DNSException):
    """DNS over QUIC (DOQ) was requested but the aioquic module is not
    available."""


# for backwards compatibility: alias of dns.xfr.TransferError, kept so that
# references to this module's TransferError keep working.
TransferError = dns.xfr.TransferError


def _compute_times(timeout):
    """Return a ``(start, expiration)`` tuple for an operation starting now.

    *start* is the current ``time.time()``.  *expiration* is ``start`` plus
    *timeout* seconds, or ``None`` if *timeout* is ``None`` (no deadline).
    """
    start = time.time()
    deadline = None if timeout is None else start + timeout
    return (start, deadline)


def _wait_for(fd, readable, writable, _, expiration):
    """Wait until *fd* is ready for the requested kind(s) of I/O.

    *readable* and *writable* select which events to wait for.  The fourth
    parameter is the historical 'error' flag; it is unused because errors are
    always reported when selecting for read or write, and there are no
    error-only waits.  *expiration* is an absolute time (or ``None`` to wait
    forever) and is converted into a relative timeout for the selector.

    Raises ``dns.exception.Timeout`` if the deadline passes first.
    """
    # An SSL socket can hold bytes that are already buffered in the TLS layer;
    # those are readable right away even though the selector would not say so.
    if readable and isinstance(fd, ssl.SSLSocket) and fd.pending() > 0:
        return True
    interest = (selectors.EVENT_READ if readable else 0) | (
        selectors.EVENT_WRITE if writable else 0
    )
    with selectors.DefaultSelector() as selector:
        if interest:
            selector.register(fd, interest)  # pyright: ignore
        timeout = None
        if expiration is not None:
            timeout = expiration - time.time()
            if timeout <= 0.0:
                raise dns.exception.Timeout
        ready = selector.select(timeout)
        if not ready:
            raise dns.exception.Timeout


def _wait_for_readable(s, expiration):
    """Wait until *s* is readable, or raise a timeout once the absolute time
    *expiration* passes (``None`` means wait forever)."""
    # '_' is the unused error flag of _wait_for().
    _wait_for(s, readable=True, writable=False, _=True, expiration=expiration)


def _wait_for_writable(s, expiration):
    """Wait until *s* is writable, or raise a timeout once the absolute time
    *expiration* passes (``None`` means wait forever)."""
    # '_' is the unused error flag of _wait_for().
    _wait_for(s, readable=False, writable=True, _=True, expiration=expiration)


def _addresses_equal(af, a1, a2):
    """Return ``True`` if address tuples *a1* and *a2* denote the same endpoint.

    The textual address in the first element of each tuple is converted to
    binary form for address family *af* before comparing, so that different
    spellings of the same address compare equal.  The remaining elements are
    compared as they are.  An address that cannot be parsed makes the result
    ``False``.
    """
    try:
        packed = [dns.inet.inet_pton(af, endpoint[0]) for endpoint in (a1, a2)]
    except dns.exception.SyntaxError:
        return False
    return packed[0] == packed[1] and a1[1:] == a2[1:]


def _matches_destination(af, from_address, destination, ignore_unexpected):
    """Check that *from_address* is a plausible source for a reply to a query
    sent to *destination*.

    Returns ``True`` if there is no destination to compare against, if the
    addresses are equal, or if the destination is a multicast address and the
    rest of the tuple matches.  Otherwise returns ``False`` when
    *ignore_unexpected* is true, and raises ``UnexpectedSource`` when it is not.
    """
    if not destination:
        return True
    if _addresses_equal(af, from_address, destination):
        return True
    # A reply to a multicast query comes from a unicast address, so only the
    # remaining tuple elements can be compared.
    if dns.inet.is_multicast(destination[0]) and from_address[1:] == destination[1:]:
        return True
    if ignore_unexpected:
        return False
    raise UnexpectedSource(
        f"got a response from {from_address} instead of {destination}"
    )


def _destination_and_source(
    where, port, source, source_port, where_must_be_address=True
):
    """Apply defaults and return an ``(af, destination, source)`` tuple.

    *destination* and *source* are low-level address tuples suitable for
    ``connect()``, ``sendto()`` or ``bind()``, or ``None`` when not applicable.
    If *where* is not an IP address the error is re-raised, unless
    *where_must_be_address* is false (URLs are acceptable in that case), and
    then there is no destination.

    Raises ``ValueError`` if *source* and *where* have different address
    families, or if only *source_port* was given and the address family cannot
    be determined.
    """
    try:
        af = dns.inet.af_for_address(where)
    except Exception:
        if where_must_be_address:
            raise
        # URLs are ok so eat the exception
        af = None
        destination = None
    else:
        destination = where

    if source:
        source_af = dns.inet.af_for_address(source)
        if not af:
            # The destination family is unknown, so the source decides it.
            af = source_af
        elif source_af != af:
            # We know the destination af, so source had better agree!
            raise ValueError("different address families for source and destination")
    elif source_port:
        # Only a port was given: a source must still be returned, so use the
        # wildcard address of the address family.
        try:
            source = dns.inet.any_for_af(af)
        except Exception:
            # we catch this and raise ValueError for backwards compatibility
            raise ValueError("source_port specified but address family is unknown")

    # Convert high-level (address, port) tuples into low-level address tuples.
    if destination:
        destination = dns.inet.low_level_address_tuple((destination, port), af)
    if source:
        source = dns.inet.low_level_address_tuple((source, source_port), af)
    return (af, destination, source)


def make_socket(
    af: socket.AddressFamily | int,
    type: socket.SocketKind,
    source: Any | None = None,
) -> socket.socket:
    """Make a socket.

    This function uses the module's ``socket_factory`` to make a socket of the
    specified address family and type.

    *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
    ``socket.AF_INET`` or ``socket.AF_INET6``.

    *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
    a datagram socket, or ``socket.SOCK_STREAM``, a stream socket.  Note that the
    ``proto`` attribute of a socket is always zero with this API, so a datagram socket
    will always be a UDP socket, and a stream socket will always be a TCP socket.

    *source* is the source address and port to bind to, if any.  The default is
    ``None`` which will bind to the wildcard address and a randomly chosen port.
    If not ``None``, it should be a (low-level) address tuple appropriate for *af*.
    """
    s = socket_factory(af, type, 0)
    try:
        s.setblocking(False)
        if source is not None:
            s.bind(source)
        return s
    except Exception:
        s.close()
        raise


def make_ssl_socket(
    af: socket.AddressFamily | int,
    type: socket.SocketKind,
    ssl_context: ssl.SSLContext,
    server_hostname: dns.name.Name | str | None = None,
    source: Any | None = None,
) -> ssl.SSLSocket:
    """Make an SSL socket.

    This function uses ``make_socket()`` (and therefore the module's
    ``socket_factory``) to make a nonblocking socket of the specified address
    family and type, and then wraps it with *ssl_context*.  The TLS handshake
    is not performed here.

    *af*, a ``socket.AddressFamily`` or ``int`` is the address family, either
    ``socket.AF_INET`` or ``socket.AF_INET6``.

    *type*, a ``socket.SocketKind`` is the type of socket, e.g. ``socket.SOCK_DGRAM``,
    a datagram socket, or ``socket.SOCK_STREAM``, a stream socket.  Note that the
    ``proto`` attribute of a socket is always zero with this API, so a datagram socket
    will always be a UDP socket, and a stream socket will always be a TCP socket.

    *ssl_context*, an ``ssl.SSLContext``, is the SSL context to use, typically
    created with ``make_ssl_context()``.

    If *server_hostname* is not ``None``, then it is the hostname to use for server
    certificate validation; a ``dns.name.Name`` is converted to text.  A valid
    hostname must be supplied if *ssl_context* requires hostname checking.

    *source* is the source address and port to bind to, if any.  The default is
    ``None``, in which case no explicit bind is done.  If not ``None``, it should
    be a (low-level) address tuple appropriate for *af*.

    If wrapping fails, the underlying socket is closed and the exception is
    re-raised.
    """
    sock = make_socket(af, type, source)
    try:
        if isinstance(server_hostname, dns.name.Name):
            server_hostname = server_hostname.to_text()
        # LGTM gets a false positive here, as our default context is OK
        return ssl_context.wrap_socket(
            sock,
            do_handshake_on_connect=False,  # lgtm[py/insecure-protocol]
            server_hostname=server_hostname,
        )
    except Exception:
        # wrap_socket() can reject its arguments before it takes ownership of
        # the plain socket; close it here so the descriptor is not leaked.
        # Closing a socket that was already taken over is harmless.
        sock.close()
        raise


# for backwards compatibility
def _make_socket(
    af,
    type,
    source,
    ssl_context,
    server_hostname,
):
    """Make a plain socket, or an SSL socket if *ssl_context* is not ``None``.

    Older entry point that dispatches to ``make_socket()`` or
    ``make_ssl_socket()``.
    """
    if ssl_context is None:
        return make_socket(af, type, source)
    return make_ssl_socket(af, type, ssl_context, server_hostname, source)


def _maybe_get_resolver(
    resolver: Optional["dns.resolver.Resolver"],  # pyright: ignore
) -> "dns.resolver.Resolver":  # pyright: ignore
    """Return *resolver*, or a newly created ``dns.resolver.Resolver`` if it
    is ``None``."""
    # This lives in its own function because the local import below binds a
    # local name "dns"; doing the import inside https() would shadow the
    # global "dns" there before it is defined.
    if resolver is not None:
        return resolver
    # pylint: disable=import-outside-toplevel,redefined-outer-name
    import dns.resolver

    return dns.resolver.Resolver()


class HTTPVersion(enum.IntEnum):
    """Which version of HTTP should be used?

    DEFAULT will select the first version from the list [2, 1.1, 3] that
    is available.
    """

    # Let the library pick a version.
    DEFAULT = 0
    # Each short name below repeats the value of the long name before it, so
    # it is an enum alias of that member rather than a separate member.
    HTTP_1 = 1
    H1 = 1
    HTTP_2 = 2
    H2 = 2
    HTTP_3 = 3
    H3 = 3


def https(
    q: dns.message.Message,
    where: str,
    timeout: float | None = None,
    port: int = 443,
    source: str | None = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    session: Any | None = None,
    path: str = "/dns-query",
    post: bool = True,
    bootstrap_address: str | None = None,
    verify: bool | str | ssl.SSLContext = True,
    resolver: Optional["dns.resolver.Resolver"] = None,  # pyright: ignore
    family: int = socket.AF_UNSPEC,
    http_version: HTTPVersion = HTTPVersion.DEFAULT,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-HTTPS.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str``, the nameserver IP address or the full URL. If an IP address is
    given, the URL will be constructed using the following schema:
    https://<IP-address>:<port>/<path>.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, a ``int``, the port to send the query to. The default is 443.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *session*, an ``httpx.Client``, or a ``dns.quic.SyncQuicConnection`` when HTTP/3
    is used.  If provided, the client session to use to send the queries.

    *path*, a ``str``. If *where* is an IP address, then *path* will be used to
    construct the URL to send the DNS query to.

    *post*, a ``bool``. If ``True``, the default, POST method will be used.

    *bootstrap_address*, a ``str``, the IP address to use to bypass resolution.

    *verify*, a ``bool``, ``str`` or ``ssl.SSLContext``.  If ``True``, then TLS
    certificate verification of the server is done using the default CA bundle; if
    ``False``, then no verification is done; if a ``str`` then it specifies the path
    to a certificate file or directory which will be used for verification.  The
    value is passed through to the HTTP (or QUIC) library.

    *resolver*, a ``dns.resolver.Resolver`` or ``None``, the resolver to use for
    resolution of hostnames in URLs.  If not specified, a new resolver with a default
    configuration will be used; note this is *not* the default resolver as that resolver
    might have been configured to use DoH causing a chicken-and-egg problem.  This
    parameter only has an effect if the HTTP library is httpx.

    *family*, an ``int``, the address family.  If socket.AF_UNSPEC (the default), both A
    and AAAA records will be retrieved.

    *http_version*, a ``dns.query.HTTPVersion``, indicating which HTTP version to use.

    Raises ``NoDOH`` if the required HTTP library is not available, ``ValueError`` if
    *session* has the wrong type, if the URL has no hostname, or if the server answers
    with a status code outside 200-299, and ``BadResponse`` if the answer is not a
    response to *q*.

    Returns a ``dns.message.Message``.
    """

    # where_must_be_address is False: *where* may be a URL, in which case af
    # is None unless a source address determines it.
    (af, _, the_source) = _destination_and_source(
        where, port, source, source_port, False
    )
    # we bind url and then override as pyright can't figure out all paths bind.
    url = where
    if af is not None and dns.inet.is_address(where):
        if af == socket.AF_INET:
            url = f"https://{where}:{port}{path}"
        elif af == socket.AF_INET6:
            # IPv6 literals must be bracketed inside a URL.
            url = f"https://[{where}]:{port}{path}"

    extensions = {}
    if bootstrap_address is None:
        # pylint: disable=possibly-used-before-assignment
        parsed = urllib.parse.urlparse(url)
        if parsed.hostname is None:
            raise ValueError("no hostname in URL")
        if dns.inet.is_address(parsed.hostname):
            # The URL host is already an address, so no resolution is needed.
            bootstrap_address = parsed.hostname
            extensions["sni_hostname"] = parsed.hostname
        if parsed.port is not None:
            # A port in the URL overrides the *port* argument.
            port = parsed.port

    # HTTP/3 is used when asked for explicitly, or as the fallback for DEFAULT
    # when httpx is not available.
    if http_version == HTTPVersion.H3 or (
        http_version == HTTPVersion.DEFAULT and not have_doh
    ):
        if bootstrap_address is None:
            # bootstrap_address was None on entry too, so 'parsed' is bound.
            resolver = _maybe_get_resolver(resolver)
            assert parsed.hostname is not None  # pyright: ignore
            answers = resolver.resolve_name(parsed.hostname, family)  # pyright: ignore
            bootstrap_address = random.choice(list(answers.addresses()))
        if session and not isinstance(
            session, dns.quic.SyncQuicConnection
        ):  # pyright: ignore
            raise ValueError("session parameter must be a dns.quic.SyncQuicConnection.")
        return _http3(
            q,
            bootstrap_address,
            url,  # pyright: ignore
            timeout,
            port,
            source,
            source_port,
            one_rr_per_rrset,
            ignore_trailing,
            verify=verify,
            post=post,
            connection=session,
        )

    if not have_doh:
        raise NoDOH  # pragma: no cover
    if session and not isinstance(session, httpx.Client):  # pyright: ignore
        raise ValueError("session parameter must be an httpx.Client")

    wire = q.to_wire()
    headers = {"accept": "application/dns-message"}

    # DEFAULT enables both HTTP/1.1 and HTTP/2.
    h1 = http_version in (HTTPVersion.H1, HTTPVersion.DEFAULT)
    h2 = http_version in (HTTPVersion.H2, HTTPVersion.DEFAULT)

    # set source port and source address

    if the_source is None:
        local_address = None
        local_port = 0
    else:
        local_address = the_source[0]
        local_port = the_source[1]

    if session:
        # A caller-supplied session is used as is and is not closed here.
        cm: contextlib.AbstractContextManager = contextlib.nullcontext(session)
    else:
        transport = _HTTPTransport(
            local_address=local_address,
            http1=h1,
            http2=h2,
            verify=verify,
            local_port=local_port,
            bootstrap_address=bootstrap_address,
            resolver=resolver,
            family=family,  # pyright: ignore
        )

        cm = httpx.Client(  # type: ignore
            http1=h1, http2=h2, verify=verify, transport=transport  # type: ignore
        )
    with cm as session:
        # see https://tools.ietf.org/html/rfc8484#section-4.1.1 for DoH
        # GET and POST examples
        assert session is not None
        if post:
            headers.update(
                {
                    "content-type": "application/dns-message",
                    "content-length": str(len(wire)),
                }
            )
            response = session.post(
                url,
                headers=headers,
                content=wire,
                timeout=timeout,
                extensions=extensions,
            )
        else:
            # GET carries the message as unpadded URL-safe base64 in the
            # "dns" query parameter.
            wire = base64.urlsafe_b64encode(wire).rstrip(b"=")
            twire = wire.decode()  # httpx does a repr() if we give it bytes
            response = session.get(
                url,
                headers=headers,
                timeout=timeout,
                params={"dns": twire},
                extensions=extensions,
            )

    # see https://tools.ietf.org/html/rfc8484#section-4.2.1 for info about DoH
    # status codes
    if response.status_code < 200 or response.status_code > 299:
        raise ValueError(
            f"{where} responded with status code {response.status_code}"
            f"\nResponse body: {response.content}"
        )
    r = dns.message.from_wire(
        response.content,
        keyring=q.keyring,
        request_mac=q.request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    # The response time is the elapsed time reported by the HTTP library.
    r.time = response.elapsed.total_seconds()
    if not q.is_response(r):
        raise BadResponse
    return r


def _find_header(headers: dns.quic.Headers, name: bytes) -> bytes:
    """Return the value of the first header in *headers* whose name equals
    *name*.

    *headers* is iterated as a sequence of ``(name, value)`` pairs.

    Raises ``KeyError`` if *headers* is ``None`` or has no such header.
    """
    if headers is not None:
        for key, val in headers:
            if key == name:
                return val
    raise KeyError


def _check_status(headers: dns.quic.Headers, peer: str, wire: bytes) -> None:
    """Check the ``:status`` pseudo-header of an HTTP/3 response.

    *headers* are the response headers, *peer* identifies the server in error
    messages, and *wire* is the response body, which is appended to the error
    message when it can be decoded as text.

    Raises ``SyntaxError`` if there is no ``:status`` header or if the status
    is negative, and ``ValueError`` if the status is outside 200-299.
    """
    # _find_header() signals a missing header with KeyError rather than by
    # returning None; translate that so the intended error below is reported.
    try:
        value = _find_header(headers, b":status")
    except KeyError:
        value = None
    if value is None:
        raise SyntaxError("no :status header in response")
    status = int(value)
    if status < 0:
        raise SyntaxError("status is negative")
    if status < 200 or status > 299:
        error = ""
        if len(wire) > 0:
            # Best effort: a body that is not decodable text is simply left
            # out of the message.
            try:
                error = ": " + wire.decode()
            except Exception:
                pass
        raise ValueError(f"{peer} responded with status code {status}{error}")


def _http3(
    q: dns.message.Message,
    where: str,
    url: str,
    timeout: float | None = None,
    port: int = 443,
    source: str | None = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    verify: bool | str | ssl.SSLContext = True,
    post: bool = True,
    connection: dns.quic.SyncQuicConnection | None = None,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-HTTP/3.

    *q* is the query; note that its ``id`` is set to 0 by this function.

    *where* is the address to connect to, and *url* is the URL to request; the
    hostname of *url* is used as the QUIC server name, and a port in *url*
    overrides *port*.

    *timeout*, *source*, *source_port*, *one_rr_per_rrset*, *ignore_trailing*,
    *verify* and *post* have the same meaning as for ``https()``.

    *connection*, a ``dns.quic.SyncQuicConnection`` or ``None``.  If provided,
    it is used instead of creating a new QUIC manager and connection.

    Raises ``NoDOH`` if QUIC support is not available, ``BadResponse`` if the
    answer is not a response to *q*, and whatever ``_check_status()`` raises
    for a bad HTTP status.

    Returns a ``dns.message.Message``.
    """
    if not dns.quic.have_quic:
        raise NoDOH("DNS-over-HTTP3 is not available.")  # pragma: no cover

    url_parts = urllib.parse.urlparse(url)
    hostname = url_parts.hostname
    assert hostname is not None
    if url_parts.port is not None:
        port = url_parts.port

    # NOTE(review): this mutates the caller's message; presumably intended,
    # since DoH recommends a message id of 0 -- confirm.
    q.id = 0
    wire = q.to_wire()
    the_connection: dns.quic.SyncQuicConnection
    the_manager: dns.quic.SyncQuicManager
    if connection:
        # Nothing to manage when the caller supplies the connection.
        manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
    else:
        manager = dns.quic.SyncQuicManager(
            verify_mode=verify, server_name=hostname, h3=True  # pyright: ignore
        )
        the_manager = manager  # for type checking happiness

    with manager:
        if connection:
            the_connection = connection
        else:
            the_connection = the_manager.connect(  # pyright: ignore
                where, port, source, source_port
            )
        # The clock starts here, after the connection has been obtained.
        (start, expiration) = _compute_times(timeout)
        with the_connection.make_stream(timeout) as stream:  # pyright: ignore
            stream.send_h3(url, wire, post)
            wire = stream.receive(_remaining(expiration))
            _check_status(stream.headers(), where, wire)
        finish = time.time()
    r = dns.message.from_wire(
        wire,
        keyring=q.keyring,
        request_mac=q.request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    # Clamp at zero in case the wall clock stepped backwards.
    r.time = max(finish - start, 0.0)
    if not q.is_response(r):
        raise BadResponse
    return r


def _udp_recv(sock, max_size, expiration):
    """Read one datagram of at most *max_size* bytes from *sock*.

    Returns whatever ``sock.recvfrom()`` returns.  While the nonblocking
    socket has nothing to read, waits for it to become readable; a Timeout
    exception is raised if that does not happen by the *expiration* time.
    """
    while True:
        try:
            datagram = sock.recvfrom(max_size)
        except BlockingIOError:
            _wait_for_readable(sock, expiration)
        else:
            return datagram


def _udp_send(sock, data, destination, expiration):
    """Send the datagram *data* over *sock*.

    If *destination* is given the datagram is sent to it with ``sendto()``;
    otherwise ``send()`` is used.  Returns the value returned by the socket
    call.  While the nonblocking socket cannot accept the datagram, waits for
    it to become writable; a Timeout exception is raised if that does not
    happen by the *expiration* time.
    """
    while True:
        try:
            if not destination:
                return sock.send(data)
            return sock.sendto(data, destination)
        except BlockingIOError:  # pragma: no cover
            _wait_for_writable(sock, expiration)


def send_udp(
    sock: Any,
    what: dns.message.Message | bytes,
    destination: Any,
    expiration: float | None = None,
) -> Tuple[int, float]:
    """Send a DNS message to the specified UDP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.  A
    message object is converted to wire format first.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where to send the query.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of bytes sent and the sent time.  The
    sent time is taken just before the datagram is handed to the socket.
    """

    payload = what.to_wire() if isinstance(what, dns.message.Message) else what
    sent_time = time.time()
    sent = _udp_send(sock, payload, destination, expiration)
    return (sent, sent_time)


def receive_udp(
    sock: Any,
    destination: Any | None = None,
    expiration: float | None = None,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
    request_mac: bytes | None = b"",
    ignore_trailing: bool = False,
    raise_on_truncation: bool = False,
    ignore_errors: bool = False,
    query: dns.message.Message | None = None,
) -> Any:
    """Read a DNS message from a UDP socket.

    *sock*, a ``socket``.

    *destination*, a destination tuple appropriate for the address family
    of the socket, specifying where the message is expected to arrive from.
    When receiving a response, this would be where the associated query was
    sent.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if
    the TC bit is set.

    *ignore_errors*, a ``bool``.  If various format errors or response
    mismatches occur, ignore them and keep listening for a valid response.
    The default is ``False``.

    *query*, a ``dns.message.Message`` or ``None``.  If not ``None`` and
    *ignore_errors* is ``True``, check that the received message is a response
    to this query, and if not keep listening for a valid response.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    If *destination* is not ``None``, returns a ``(dns.message.Message, float)``
    tuple of the received message and the received time.

    If *destination* is ``None``, returns a
    ``(dns.message.Message, float, tuple)``
    tuple of the received message, the received time, and the address where
    the message arrived from.
    """

    wire = b""
    # Keep reading datagrams until one is acceptable, an error is raised, or
    # the expiration time is reached (the timeout comes from _udp_recv()).
    while True:
        (wire, from_address) = _udp_recv(sock, 65535, expiration)
        # Either raises UnexpectedSource or, with ignore_unexpected, reports a
        # mismatch so that the datagram is skipped.
        if not _matches_destination(
            sock.family, from_address, destination, ignore_unexpected
        ):
            continue
        received_time = time.time()
        try:
            r = dns.message.from_wire(
                wire,
                keyring=keyring,
                request_mac=request_mac,
                one_rr_per_rrset=one_rr_per_rrset,
                ignore_trailing=ignore_trailing,
                raise_on_truncation=raise_on_truncation,
            )
        except dns.message.Truncated as e:
            # If we got Truncated and not FORMERR, we at least got the header with TC
            # set, and very likely the question section, so we'll re-raise if the
            # message seems to be a response as we need to know when truncation happens.
            # We need to check that it seems to be a response as we don't want a random
            # injected message with TC set to cause us to bail out.
            if (
                ignore_errors
                and query is not None
                and not query.is_response(e.message())
            ):
                continue
            else:
                raise
        except Exception:
            if ignore_errors:
                continue
            else:
                raise
        # In ignore_errors mode a message that does not answer the query is
        # dropped here rather than returned to the caller.
        if ignore_errors and query is not None and not query.is_response(r):
            continue
        if destination:
            return (r, received_time)
        else:
            return (r, received_time, from_address)


def udp(
    q: dns.message.Message,
    where: str,
    timeout: float | None = None,
    port: int = 53,
    source: str | None = None,
    source_port: int = 0,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    raise_on_truncation: bool = False,
    sock: Any | None = None,
    ignore_errors: bool = False,
) -> dns.message.Message:
    """Send a query via UDP and return the response.

    *q*, a ``dns.message.Message``, the query to send

    *where*, a ``str`` containing an IPv4 or IPv6 address, where to send the
    message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the
    source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.  The
    default is 0.

    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
    unexpected sources.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing junk at end of
    the received message.

    *raise_on_truncation*, a ``bool``.  If ``True``, raise an exception if the
    TC bit is set.

    *sock*, a ``socket.socket``, or ``None``, the socket to use for the query.
    If ``None``, the default, a socket is created and closed again when the
    query is finished.  Note that if a socket is provided, it must be a
    nonblocking datagram socket, and the *source* and *source_port* are
    ignored.

    *ignore_errors*, a ``bool``.  If various format errors or response
    mismatches occur, ignore them and keep listening for a valid response.
    The default is ``False``.

    Raises ``BadResponse`` if the received message is not a response to *q*.

    Returns a ``dns.message.Message``.
    """

    wire = q.to_wire()
    (af, destination, bind_to) = _destination_and_source(
        where, port, source, source_port, True
    )
    (begin_time, expiration) = _compute_times(timeout)
    cm: contextlib.AbstractContextManager
    if sock:
        # A caller-supplied socket is used as is and left open.
        cm = contextlib.nullcontext(sock)
    else:
        assert af is not None
        cm = make_socket(af, socket.SOCK_DGRAM, bind_to)
    with cm as s:
        send_udp(s, wire, destination, expiration)
        (response, received_time) = receive_udp(
            s,
            destination=destination,
            expiration=expiration,
            ignore_unexpected=ignore_unexpected,
            one_rr_per_rrset=one_rr_per_rrset,
            keyring=q.keyring,
            request_mac=q.mac,
            ignore_trailing=ignore_trailing,
            raise_on_truncation=raise_on_truncation,
            ignore_errors=ignore_errors,
            query=q,
        )
        response.time = received_time - begin_time
        # We don't need to check q.is_response() if we are in ignore_errors mode
        # as receive_udp() will have checked it.
        if not ignore_errors and not q.is_response(response):
            raise BadResponse
        return response
    assert (
        False  # help mypy figure out we can't get here  lgtm[py/unreachable-statement]
    )


def udp_with_fallback(
    q: dns.message.Message,
    where: str,
    timeout: float | None = None,
    port: int = 53,
    source: str | None = None,
    source_port: int = 0,
    ignore_unexpected: bool = False,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    udp_sock: Any | None = None,
    tcp_sock: Any | None = None,
    ignore_errors: bool = False,
) -> Tuple[dns.message.Message, bool]:
    """Send the query over UDP, and retry it over TCP if the UDP response
    is truncated.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, the server to query.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before a
    query times out.  If ``None``, the default, wait forever.  The same value is
    handed to the UDP attempt and to the TCP attempt.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.  The
    default is 0.

    *ignore_unexpected*, a ``bool``.  If ``True``, ignore responses from
    unexpected sources.  Only passed to the UDP query.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing junk at end of
    the received message.

    *udp_sock*, a ``socket.socket``, or ``None``, the socket to use for the UDP
    query.  If ``None``, the default, a socket is created.  A supplied socket must
    be a nonblocking datagram socket, and *source* and *source_port* are then
    ignored for the UDP query.

    *tcp_sock*, a ``socket.socket``, or ``None``, the connected socket to use for
    the TCP query.  If ``None``, the default, a socket is created.  A supplied
    socket must be a nonblocking connected stream socket, and *where*, *source*
    and *source_port* are then ignored for the TCP query.

    *ignore_errors*, a ``bool``.  If various format errors or response mismatches
    occur while listening for UDP, ignore them and keep listening for a valid
    response.  The default is ``False``.  Only passed to the UDP query.

    Returns a ``(dns.message.Message, tcp)`` tuple, where ``tcp`` is ``True`` if
    and only if the response was obtained via TCP.
    """
    try:
        # The literal True fills the raise-on-truncation slot of udp(), so a
        # truncated answer surfaces as dns.message.Truncated.
        answer = udp(
            q,
            where,
            timeout,
            port,
            source,
            source_port,
            ignore_unexpected,
            one_rr_per_rrset,
            ignore_trailing,
            True,
            udp_sock,
            ignore_errors,
        )
    except dns.message.Truncated:
        answer = tcp(
            q,
            where,
            timeout,
            port,
            source,
            source_port,
            one_rr_per_rrset,
            ignore_trailing,
            tcp_sock,
        )
        return (answer, True)
    return (answer, False)


def _net_read(sock, count, expiration):
    """Read exactly *count* bytes from *sock* and return them as ``bytes``.

    Reading is retried until the requested amount has arrived.  ``EOFError``
    is raised if the peer closes the connection first, and a Timeout
    exception is raised if the read has not completed by *expiration*.
    """
    chunks = []
    remaining = count
    while remaining > 0:
        try:
            chunk = sock.recv(remaining)
        except (BlockingIOError, ssl.SSLWantReadError):
            # Nothing to read yet; wait (bounded by expiration) and retry.
            _wait_for_readable(sock, expiration)
            continue
        except ssl.SSLWantWriteError:  # pragma: no cover
            # The TLS layer needs to write before it can deliver more data.
            _wait_for_writable(sock, expiration)
            continue
        if chunk == b"":
            raise EOFError("EOF")
        remaining -= len(chunk)
        chunks.append(chunk)
    return b"".join(chunks)


def _net_write(sock, data, expiration):
    """Send all of *data* on *sock*, retrying after partial or blocked sends.

    A Timeout exception is raised if the write has not completed by
    *expiration*.
    """
    pending = data
    while pending:
        try:
            written = sock.send(pending)
        except (BlockingIOError, ssl.SSLWantWriteError):
            # The socket cannot take more data yet; wait and retry.
            _wait_for_writable(sock, expiration)
            continue
        except ssl.SSLWantReadError:  # pragma: no cover
            # The TLS layer needs to read before it can send more data.
            _wait_for_readable(sock, expiration)
            continue
        # Drop what was accepted; whatever is left is sent on the next pass.
        pending = pending[written:]


def send_tcp(
    sock: Any,
    what: dns.message.Message | bytes,
    expiration: float | None = None,
) -> Tuple[int, float]:
    """Write a length-prefixed DNS message to a TCP socket.

    *sock*, a ``socket``.

    *what*, a ``bytes`` or ``dns.message.Message``, the message to send.  A
    ``bytes`` value is taken to be the wire format without the length prefix.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    Returns an ``(int, float)`` tuple of the number of bytes sent (including
    the two-byte length prefix) and the time at which sending started.
    """

    if not isinstance(what, dns.message.Message):
        # Building one buffer costs a copy of the wire data, but it avoids
        # needing writev() and avoids a short write of just the prefix being
        # pushed onto the net by itself.
        framed = len(what).to_bytes(2, "big") + what
    else:
        framed = what.to_wire(prepend_length=True)
    sent_time = time.time()
    _net_write(sock, framed, expiration)
    return (len(framed), sent_time)


def receive_tcp(
    sock: Any,
    expiration: float | None = None,
    one_rr_per_rrset: bool = False,
    keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
    request_mac: bytes | None = b"",
    ignore_trailing: bool = False,
) -> Tuple[dns.message.Message, float]:
    """Read one length-prefixed DNS message from a TCP socket.

    *sock*, a ``socket``.

    *expiration*, a ``float`` or ``None``, the absolute time at which
    a timeout exception should be raised.  If ``None``, no timeout will
    occur.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *request_mac*, a ``bytes`` or ``None``, the MAC of the request (for TSIG).

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    Raises if the message is malformed, if network errors occur, or if
    there is a timeout.

    Returns a ``(dns.message.Message, float)`` tuple of the received message
    and the received time.
    """

    # The message is preceded by its length as a 16-bit big-endian integer.
    length_prefix = _net_read(sock, 2, expiration)
    message_length = int.from_bytes(length_prefix, "big")
    payload = _net_read(sock, message_length, expiration)
    received_time = time.time()
    message = dns.message.from_wire(
        payload,
        keyring=keyring,
        request_mac=request_mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    return (message, received_time)


def _connect(s, address, expiration):
    """Connect socket *s* to *address*, waiting for a pending connect to finish.

    Raises ``OSError`` carrying the socket error code if the connection
    fails.  Waiting for a pending connect is bounded by *expiration*.
    """
    status = s.connect_ex(address)
    in_progress = (errno.EINPROGRESS, errno.EWOULDBLOCK, errno.EALREADY)
    if status in in_progress:
        # The connect has not finished yet: wait until the socket is
        # writable, then read the final outcome from SO_ERROR.
        _wait_for_writable(s, expiration)
        status = s.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR)
    if status:
        raise OSError(status, os.strerror(status))


def tcp(
    q: dns.message.Message,
    where: str,
    timeout: float | None = None,
    port: int = 53,
    source: str | None = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    sock: Any | None = None,
) -> dns.message.Message:
    """Send a query via TCP and return the response.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, a ``socket.socket``, or ``None``, the connected socket to use for the
    query.  If ``None``, the default, a socket is created.  A supplied socket
    must be a nonblocking connected stream socket, and *where*, *port*,
    *source* and *source_port* are then ignored.

    Raises ``BadResponse`` if the message received is not a response to *q*.

    Returns a ``dns.message.Message``.
    """

    request_wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    socket_cm: contextlib.AbstractContextManager
    if not sock:
        (af, destination, source) = _destination_and_source(
            where, port, source, source_port, True
        )
        assert af is not None
        socket_cm = make_socket(af, socket.SOCK_STREAM, source)
    else:
        # A caller-supplied socket is used as-is; nullcontext leaves it open
        # when the with block exits.
        socket_cm = contextlib.nullcontext(sock)
    with socket_cm as stream:
        if not sock:
            # Only a socket created here still needs connecting.
            # pylint: disable=possibly-used-before-assignment
            _connect(stream, destination, expiration)  # pyright: ignore
        send_tcp(stream, request_wire, expiration)
        (response, received_time) = receive_tcp(
            stream, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
        )
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
    assert (
        False  # help mypy figure out we can't get here  lgtm[py/unreachable-statement]
    )


def _tls_handshake(s, expiration):
    """Drive the TLS handshake on *s* until it completes.

    Each time the handshake reports that it needs the socket to become
    readable or writable, wait for that (bounded by *expiration*) and retry.
    """
    while True:
        try:
            s.do_handshake()
        except ssl.SSLWantReadError:
            _wait_for_readable(s, expiration)
        except ssl.SSLWantWriteError:  # pragma: no cover
            _wait_for_writable(s, expiration)
        else:
            # No exception means the handshake has finished.
            return


def make_ssl_context(
    verify: bool | str = True,
    check_hostname: bool = True,
    alpns: list[str] | None = None,
) -> ssl.SSLContext:
    """Make an SSL context.

    If *verify* is ``True``, the default, then certificate verification will occur using
    the standard CA roots.  If *verify* is ``False``, then certificate verification will
    be disabled.  If *verify* is a string which is a valid pathname, then if the
    pathname is a regular file, the CA roots will be taken from the file, otherwise if
    the pathname is a directory roots will be taken from the directory.

    If *check_hostname* is ``True``, the default, then the hostname of the server must
    be specified when connecting and the server's certificate must authorize the
    hostname.  If ``False``, then hostname checking is disabled.

    *alpns* is ``None`` or a list of TLS ALPN (Application Layer Protocol Negotiation)
    strings to use in negotiation.  For DNS-over-TLS, the right value is `["dot"]`.

    Returns an ``ssl.SSLContext`` whose minimum protocol version is TLS 1.2.
    """
    (cafile, capath) = dns._tls_util.convert_verify_to_cafile_and_capath(verify)
    context = ssl.create_default_context(cafile=cafile, capath=capath)
    # The "type: ignore" markers below are needed because type checkers get
    # confused between the _no_ssl compatibility types and the real ones.
    context.minimum_version = ssl.TLSVersion.TLSv1_2  # type: ignore
    # Order matters: check_hostname is assigned before verify_mode, because
    # the ssl module rejects CERT_NONE while hostname checking is enabled.
    context.check_hostname = check_hostname
    if verify is False:
        context.verify_mode = ssl.CERT_NONE  # type: ignore
    if alpns is not None:
        context.set_alpn_protocols(alpns)
    return context  # type: ignore


# for backwards compatibility
def _make_dot_ssl_context(
    server_hostname: str | None, verify: bool | str
) -> ssl.SSLContext:
    """Build an SSL context offering the "dot" ALPN.

    Hostname checking is enabled only when *server_hostname* is not ``None``;
    *verify* is passed through to ``make_ssl_context()``.
    """
    have_hostname = server_hostname is not None
    return make_ssl_context(
        verify=verify, check_hostname=have_hostname, alpns=["dot"]
    )


def tls(
    q: dns.message.Message,
    where: str,
    timeout: float | None = None,
    port: int = 853,
    source: str | None = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    sock: ssl.SSLSocket | None = None,
    ssl_context: ssl.SSLContext | None = None,
    server_hostname: str | None = None,
    verify: bool | str = True,
) -> dns.message.Message:
    """Send a query via TLS and return the response.

    *q*, a ``dns.message.Message``, the query to send.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the
    query times out.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 853.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *one_rr_per_rrset*, a ``bool``.  If ``True``, put each RR into its own
    RRset.

    *ignore_trailing*, a ``bool``.  If ``True``, ignore trailing
    junk at end of the received message.

    *sock*, an ``ssl.SSLSocket``, or ``None``, the socket to use for
    the query.  If ``None``, the default, a socket is created.  A supplied
    socket must be a nonblocking connected SSL stream socket, and *where*,
    *port*, *source*, *source_port*, and *ssl_context* are then ignored.

    *ssl_context*, an ``ssl.SSLContext``, the context to use when establishing
    a TLS connection.  If ``None``, the default, one is created with
    ``make_ssl_context()`` using *verify* and the "dot" ALPN.

    *server_hostname*, a ``str`` containing the server's hostname.  The
    default is ``None``, which means that no hostname is known, and if an
    SSL context is created, hostname checking will be disabled.

    *verify*, a ``bool`` or ``str``.  If ``True``, then TLS certificate verification
    of the server is done using the default CA bundle; if ``False``, then no
    verification is done; if a `str` then it specifies the path to a certificate file or
    directory which will be used for verification.

    Raises ``BadResponse`` if the message received is not a response to *q*.

    Returns a ``dns.message.Message``.

    """

    if sock:
        #
        # A socket supplied by the caller needs no special TLS handling
        # here, so the query is simply delegated to tcp().
        #
        return tcp(
            q,
            where,
            timeout,
            port,
            source,
            source_port,
            one_rr_per_rrset,
            ignore_trailing,
            sock,
        )

    request_wire = q.to_wire()
    (begin_time, expiration) = _compute_times(timeout)
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port, True
    )
    assert af is not None  # where must be an address
    if ssl_context is None:
        check_hostname = server_hostname is not None
        ssl_context = make_ssl_context(verify, check_hostname, ["dot"])

    with make_ssl_socket(
        af,
        socket.SOCK_STREAM,
        ssl_context=ssl_context,
        server_hostname=server_hostname,
        source=source,
    ) as tls_sock:
        _connect(tls_sock, destination, expiration)
        _tls_handshake(tls_sock, expiration)
        send_tcp(tls_sock, request_wire, expiration)
        (response, received_time) = receive_tcp(
            tls_sock, expiration, one_rr_per_rrset, q.keyring, q.mac, ignore_trailing
        )
        response.time = received_time - begin_time
        if not q.is_response(response):
            raise BadResponse
        return response
    assert (
        False  # help mypy figure out we can't get here  lgtm[py/unreachable-statement]
    )


def quic(
    q: dns.message.Message,
    where: str,
    timeout: float | None = None,
    port: int = 853,
    source: str | None = None,
    source_port: int = 0,
    one_rr_per_rrset: bool = False,
    ignore_trailing: bool = False,
    connection: dns.quic.SyncQuicConnection | None = None,
    verify: bool | str = True,
    hostname: str | None = None,
    server_hostname: str | None = None,
) -> dns.message.Message:
    """Return the response obtained after sending a query via DNS-over-QUIC.

    *q*, a ``dns.message.Message``, the query to send.  Note that the id of
    *q* is set to 0 before it is sent.

    *where*, a ``str``, the nameserver IP address.

    *timeout*, a ``float`` or ``None``, the number of seconds to wait before the query
    times out. If ``None``, the default, wait forever.

    *port*, a ``int``, the port to send the query to. The default is 853.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying the source
    address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message. The default is
    0.

    *one_rr_per_rrset*, a ``bool``. If ``True``, put each RR into its own RRset.

    *ignore_trailing*, a ``bool``. If ``True``, ignore trailing junk at end of the
    received message.

    *connection*, a ``dns.quic.SyncQuicConnection``.  If provided, the connection to use
    to send the query; no new connection is made, so *where*, *port*, *source*,
    *source_port*, *verify* and *hostname* are not used.

    *verify*, a ``bool`` or ``str``.  If a ``True``, then TLS certificate verification
    of the server is done using the default CA bundle; if ``False``, then no
    verification is done; if a `str` then it specifies the path to a certificate file or
    directory which will be used for verification.

    *hostname*, a ``str`` containing the server's hostname or ``None``.  The default is
    ``None``, which means that no hostname is known, and if an SSL context is created,
    hostname checking will be disabled.

    *server_hostname*, a ``str`` or ``None``.  This item is for backwards compatibility
    only, and has the same meaning as *hostname*.  It is only consulted when
    *hostname* is ``None``.

    Raises ``NoDOQ`` if DNS-over-QUIC support is not available, and
    ``BadResponse`` if the message received is not a response to *q*.

    Returns a ``dns.message.Message``.
    """

    if not dns.quic.have_quic:
        raise NoDOQ("DNS-over-QUIC is not available.")  # pragma: no cover

    if server_hostname is not None and hostname is None:
        hostname = server_hostname

    # DNS-over-QUIC (RFC 9250) requires the DNS message ID to be 0.
    q.id = 0
    wire = q.to_wire()
    the_connection: dns.quic.SyncQuicConnection
    the_manager: dns.quic.SyncQuicManager
    if connection:
        # Reuse the caller's connection; there is no manager to enter or exit.
        manager: contextlib.AbstractContextManager = contextlib.nullcontext(None)
        the_connection = connection
    else:
        manager = dns.quic.SyncQuicManager(
            verify_mode=verify, server_name=hostname  # pyright: ignore
        )
        the_manager = manager  # for type checking happiness

    with manager:
        if not connection:
            the_connection = the_manager.connect(  # pyright: ignore
                where, port, source, source_port
            )
        (start, expiration) = _compute_times(timeout)
        with the_connection.make_stream(timeout) as stream:  # pyright: ignore
            stream.send(wire, True)
            wire = stream.receive(_remaining(expiration))
        finish = time.time()
    # Validate any TSIG on the response against the MAC of the query we just
    # sent (q.mac), as the other transports in this module do.  q.request_mac
    # is the MAC of a request that q itself would be answering, which is empty
    # for a query and would make TSIG validation of the response fail.
    r = dns.message.from_wire(
        wire,
        keyring=q.keyring,
        request_mac=q.mac,
        one_rr_per_rrset=one_rr_per_rrset,
        ignore_trailing=ignore_trailing,
    )
    r.time = max(finish - start, 0.0)
    if not q.is_response(r):
        raise BadResponse
    return r


class UDPMode(enum.IntEnum):
    """How should UDP be used in an IXFR from :py:func:`inbound_xfr()`?

    NEVER means "never use UDP; always use TCP"
    TRY_FIRST means "try to use UDP but fall back to TCP if needed"
    ONLY means "raise ``dns.xfr.UseTCP`` if trying UDP does not succeed"

    The mode is only consulted for IXFR queries; any other transfer type
    goes straight to TCP.
    """

    # Skip the UDP attempt entirely.
    NEVER = 0
    # Attempt UDP; on dns.xfr.UseTCP, retry the transfer over TCP.
    TRY_FIRST = 1
    # Attempt UDP; on dns.xfr.UseTCP, propagate the exception to the caller.
    ONLY = 2


def _inbound_xfr(
    txn_manager: dns.transaction.TransactionManager,
    s: socket.socket | ssl.SSLSocket,
    query: dns.message.Message,
    serial: int | None,
    timeout: float | None,
    expiration: float | None,
) -> Any:
    """Given a socket, does the zone transfer.

    This is a generator: it sends *query* once on *s*, then repeatedly reads
    a response message, hands it to a ``dns.xfr.Inbound`` built on
    *txn_manager*, and yields it, until ``process_message()`` reports that
    the transfer is done.

    *timeout* is the number of seconds allowed for each response message,
    and *expiration* is the absolute deadline for the whole transfer; either
    may be ``None``.

    Raises ``dns.exception.FormError`` if *query* has a keyring but the last
    message received had no TSIG.
    """
    rdtype = query.question[0].rdtype
    is_ixfr = rdtype == dns.rdatatype.IXFR
    origin = txn_manager.from_wire_origin()
    wire = query.to_wire()
    # Only a plain datagram socket is treated as UDP; everything else uses
    # the length-prefixed stream framing.
    is_udp = isinstance(s, socket.socket) and s.type == socket.SOCK_DGRAM
    if is_udp:
        # No destination is passed here; presumably the socket is already
        # connected by the caller - confirm against callers.
        _udp_send(s, wire, None, expiration)
    else:
        # Stream framing: 16-bit big-endian length followed by the message.
        tcpmsg = struct.pack("!H", len(wire)) + wire
        _net_write(s, tcpmsg, expiration)
    with dns.xfr.Inbound(txn_manager, rdtype, serial, is_udp) as inbound:
        done = False
        tsig_ctx = None
        r: dns.message.Message | None = None
        while not done:
            # Deadline for this message: *timeout* seconds from now, but
            # never later than the overall *expiration*.
            (_, mexpiration) = _compute_times(timeout)
            if mexpiration is None or (
                expiration is not None and mexpiration > expiration
            ):
                mexpiration = expiration
            if is_udp:
                (rwire, _) = _udp_recv(s, 65535, mexpiration)
            else:
                ldata = _net_read(s, 2, mexpiration)
                (l,) = struct.unpack("!H", ldata)
                rwire = _net_read(s, l, mexpiration)
            # tsig_ctx carries the TSIG state of the previous message into
            # the parsing of this one; multi-message mode is used for
            # everything except UDP.
            r = dns.message.from_wire(
                rwire,
                keyring=query.keyring,
                request_mac=query.mac,
                xfr=True,
                origin=origin,
                tsig_ctx=tsig_ctx,
                multi=(not is_udp),
                one_rr_per_rrset=is_ixfr,
            )
            done = inbound.process_message(r)
            yield r
            tsig_ctx = r.tsig_ctx
        # If the query used a keyring, the final message must have had a TSIG.
        if query.keyring and r is not None and not r.had_tsig:
            raise dns.exception.FormError("missing TSIG")


def xfr(
    where: str,
    zone: dns.name.Name | str,
    rdtype: dns.rdatatype.RdataType | str = dns.rdatatype.AXFR,
    rdclass: dns.rdataclass.RdataClass | str = dns.rdataclass.IN,
    timeout: float | None = None,
    port: int = 53,
    keyring: Dict[dns.name.Name, dns.tsig.Key] | None = None,
    keyname: dns.name.Name | str | None = None,
    relativize: bool = True,
    lifetime: float | None = None,
    source: str | None = None,
    source_port: int = 0,
    serial: int = 0,
    use_udp: bool = False,
    keyalgorithm: dns.name.Name | str = dns.tsig.default_algorithm,
) -> Any:
    """Return a generator for the responses to a zone transfer.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *zone*, a ``dns.name.Name`` or ``str``, the name of the zone to transfer.

    *rdtype*, an ``int`` or ``str``, the type of zone transfer.  The
    default is ``dns.rdatatype.AXFR``.  ``dns.rdatatype.IXFR`` can be
    used to do an incremental transfer instead.

    *rdclass*, an ``int`` or ``str``, the class of the zone transfer.
    The default is ``dns.rdataclass.IN``.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message.  If ``None``, the default, wait forever.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *keyring*, a ``dict``, the keyring to use for TSIG.

    *keyname*, a ``dns.name.Name`` or ``str``, the name of the TSIG
    key to use.

    *relativize*, a ``bool``.  If ``True``, all names in the zone will be
    relativized to the zone origin.  It is essential that the
    relativize setting matches the one specified to
    ``dns.zone.from_xfr()`` if using this generator to make a zone.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer.  If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *serial*, an ``int``, the SOA serial number to use as the base for
    an IXFR diff sequence (only meaningful if *rdtype* is
    ``dns.rdatatype.IXFR``).

    *use_udp*, a ``bool``.  If ``True``, use UDP (only meaningful for IXFR).
    Requesting UDP for any other transfer type raises ``ValueError``.

    *keyalgorithm*, a ``dns.name.Name`` or ``str``, the TSIG algorithm to use.

    Raises on errors, and so does the generator.

    Returns a generator of ``dns.message.Message`` objects.
    """

    class DummyTransactionManager(dns.transaction.TransactionManager):
        # A stand-in manager: it supplies origin information and a writer
        # whose every method does nothing, so nothing is stored.
        def __init__(self, origin, relativize):
            effective = dns.name.empty if relativize else origin
            self.info = (origin, relativize, effective)

        def origin_information(self):
            return self.info

        def get_class(self) -> dns.rdataclass.RdataClass:
            raise NotImplementedError  # pragma: no cover

        def reader(self):
            raise NotImplementedError  # pragma: no cover

        def writer(self, replacement: bool = False) -> dns.transaction.Transaction:
            class DummyTransaction:
                def nop(self, *args, **kw):
                    pass

                # Any attribute that is looked up resolves to the no-op.
                def __getattr__(self, _):
                    return self.nop

            return cast(dns.transaction.Transaction, DummyTransaction())

    if isinstance(zone, str):
        zone = dns.name.from_text(zone)
    rdtype = dns.rdatatype.RdataType.make(rdtype)
    query = dns.message.make_query(zone, rdtype, rdclass)
    wants_ixfr = rdtype == dns.rdatatype.IXFR
    if wants_ixfr:
        # An IXFR query carries an SOA in the authority section holding the
        # serial the caller already has.
        soa_rrset = query.find_rrset(
            query.authority, zone, dns.rdataclass.IN, dns.rdatatype.SOA, create=True
        )
        soa_rrset.add(dns.rdata.from_text("IN", "SOA", f". . {serial} 0 0 0 0"), 0)
    if keyring is not None:
        query.use_tsig(keyring, keyname, algorithm=keyalgorithm)
    (af, destination, source) = _destination_and_source(
        where, port, source, source_port, True
    )
    assert af is not None
    (_, expiration) = _compute_times(lifetime)
    manager = DummyTransactionManager(zone, relativize)
    if use_udp and not wants_ixfr:
        raise ValueError("cannot do a UDP AXFR")
    if use_udp:
        sock_type = socket.SOCK_DGRAM
    else:
        sock_type = socket.SOCK_STREAM
    with make_socket(af, sock_type, source) as xfr_sock:
        _connect(xfr_sock, destination, expiration)
        yield from _inbound_xfr(
            manager, xfr_sock, query, serial, timeout, expiration
        )


def inbound_xfr(
    where: str,
    txn_manager: dns.transaction.TransactionManager,
    query: dns.message.Message | None = None,
    port: int = 53,
    timeout: float | None = None,
    lifetime: float | None = None,
    source: str | None = None,
    source_port: int = 0,
    udp_mode: UDPMode = UDPMode.NEVER,
) -> None:
    """Conduct an inbound transfer and apply it via a transaction from the
    txn_manager.

    *where*, a ``str`` containing an IPv4 or IPv6 address, where
    to send the message.

    *txn_manager*, a ``dns.transaction.TransactionManager``, the txn_manager
    for this transfer (typically a ``dns.zone.Zone``).

    *query*, the query to send.  If not supplied, a default query is
    constructed using information from the *txn_manager*.

    *port*, an ``int``, the port to send the message to.  The default is 53.

    *timeout*, a ``float``, the number of seconds to wait for each
    response message.  If ``None``, the default, wait forever.

    *lifetime*, a ``float``, the total number of seconds to spend
    doing the transfer.  If ``None``, the default, then there is no
    limit on the time the transfer may take.

    *source*, a ``str`` containing an IPv4 or IPv6 address, specifying
    the source address.  The default is the wildcard address.

    *source_port*, an ``int``, the port from which to send the message.
    The default is 0.

    *udp_mode*, a ``dns.query.UDPMode``, determines how UDP is used
    for IXFRs.  The default is ``dns.query.UDPMode.NEVER``, i.e. only use
    TCP.  Other possibilities are ``dns.query.UDPMode.TRY_FIRST``, which
    means "try UDP but fall back to TCP if needed", and
    ``dns.query.UDPMode.ONLY``, which means "try UDP and raise
    ``dns.xfr.UseTCP`` if it does not succeed".

    Raises on errors.
    """
    if query is not None:
        serial = dns.xfr.extract_serial_from_query(query)
    else:
        (query, serial) = dns.xfr.make_query(txn_manager)

    (af, destination, source) = _destination_and_source(
        where, port, source, source_port, True
    )
    assert af is not None
    (_, expiration) = _compute_times(lifetime)
    is_ixfr = query.question[0].rdtype == dns.rdatatype.IXFR
    if is_ixfr and udp_mode != UDPMode.NEVER:
        with make_socket(af, socket.SOCK_DGRAM, source) as udp_sock:
            _connect(udp_sock, destination, expiration)
            try:
                for _ in _inbound_xfr(
                    txn_manager, udp_sock, query, serial, timeout, expiration
                ):
                    pass
            except dns.xfr.UseTCP:
                # UDP was not sufficient: give up in ONLY mode, otherwise
                # fall through to the TCP transfer below.
                if udp_mode == UDPMode.ONLY:
                    raise
            else:
                # The transfer completed over UDP.
                return

    with make_socket(af, socket.SOCK_STREAM, source) as tcp_sock:
        _connect(tcp_sock, destination, expiration)
        for _ in _inbound_xfr(
            txn_manager, tcp_sock, query, serial, timeout, expiration
        ):
            pass
