aiorpcX-0.24/.azure-pipelines/ci.yml

pr:
- master
- releases/*

jobs:
- template: run-tests.yml
  parameters:
    name: Py313_Ubuntu
    vmImage: 'ubuntu-latest'
    pythonVersion: '3.13'
- template: run-tests.yml
  parameters:
    name: Py313_Mac
    vmImage: 'macOS-latest'
    pythonVersion: '3.13'
- template: run-tests.yml
  parameters:
    name: Py313_Win
    vmImage: 'windows-latest'
    pythonVersion: '3.13'
- template: run-tests.yml
  parameters:
    name: Py312_Ubuntu
    vmImage: 'ubuntu-latest'
    pythonVersion: '3.12'
- template: run-tests.yml
  parameters:
    name: Py312_Mac
    vmImage: 'macOS-latest'
    pythonVersion: '3.12'
- template: run-tests.yml
  parameters:
    name: Py312_Win
    vmImage: 'windows-latest'
    pythonVersion: '3.12'
- template: run-tests.yml
  parameters:
    name: Py311_Ubuntu
    vmImage: 'ubuntu-latest'
    pythonVersion: '3.11'
- template: run-tests.yml
  parameters:
    name: Py311_Mac
    vmImage: 'macOS-latest'
    pythonVersion: '3.11'
- template: run-tests.yml
  parameters:
    name: Py311_Win
    vmImage: 'windows-latest'
    pythonVersion: '3.11'
- template: run-tests.yml
  parameters:
    name: Py310_Ubuntu
    vmImage: 'ubuntu-latest'
    pythonVersion: '3.10'
- template: run-tests.yml
  parameters:
    name: Py310_Mac
    vmImage: 'macOS-latest'
    pythonVersion: '3.10'
- template: run-tests.yml
  parameters:
    name: Py310_Win
    vmImage: 'windows-latest'
    pythonVersion: '3.10'
- template: run-tests.yml
  parameters:
    name: Py39_Ubuntu
    vmImage: 'ubuntu-latest'
    pythonVersion: '3.9'
- template: run-tests.yml
  parameters:
    name: Py39_Mac
    vmImage: 'macOS-latest'
    pythonVersion: '3.9'
- template: run-tests.yml
  parameters:
    name: Py39_Win
    vmImage: 'windows-latest'
    pythonVersion: '3.9'

aiorpcX-0.24/.azure-pipelines/prepare-env.yml

parameters:
  onlyPullRequests: false

steps:
- script: |
    python -m pip install websockets
    python -m pip install uvloop
    python -m pip install flake8
    python -m pip install coveralls coverage
    python -m pip install pytest pytest-asyncio Sphinx
  displayName: Prepare general environment
  condition: |
    and(
      succeeded(),
      or(
        eq(variables['Build.Reason'], 'PullRequest'),
        eq(${{ parameters.onlyPullRequests }}, false)
      )
    )
  enabled: true
  continueOnError: false
  failOnStderr: false

aiorpcX-0.24/.azure-pipelines/run-tests.yml

parameters:
  name: ''  # defaults for any parameters that aren't specified
  vmImage: ''
  pythonVersion: ''

jobs:
- job: ${{ parameters.name }}
  pool:
    vmImage: ${{ parameters.vmImage }}
  steps:
  - task: UsePythonVersion@0
    inputs:
      versionSpec: ${{ parameters.pythonVersion }}
      addToPath: true
      architecture: x64
  - template: prepare-env.yml
  - script: |
      coverage run -m pytest --junitxml=junit/test-results.xml tests && coverage xml
    displayName: 'Test with pytest'
  - bash: flake8
    displayName: flake8
  - task: PublishTestResults@2
    condition: succeededOrFailed()
    inputs:
      testResultsFiles: 'junit/test-*.xml'
      testRunTitle: 'Publish test results for Python ${{ parameters.pythonVersion }}'
  - task: PublishCodeCoverageResults@2
    inputs:
      codeCoverageTool: cobertura
      summaryFileLocation: coverage.xml
  - bash: |
      COVERALLS_REPO_TOKEN=$(CRT) coveralls
    displayName: 'Coveralls'

aiorpcX-0.24/.coveragerc

[run]
omit = tests/*

aiorpcX-0.24/.flake8

[flake8]
max_line_length=99
exclude=aiorpcx/__init__.py
    build/
    docs/conf.py
    tests/conftest.py

aiorpcX-0.24/.gitignore

**/__pycache__/
**/*~
**/*.#*
.pytest_cache/
.coverage
.cache/
htmlcov/
dist/
build/
aiorpcX.egg-info/

aiorpcX-0.24/LICENCE

Copyright (c) 2018 Neil Booth

All rights reserved.

The MIT License (MIT)

Permission is hereby granted, free of charge, to any person obtaining
a copy of this software and associated documentation files (the
"Software"), to deal in the Software without restriction, including
without limitation the rights to use, copy, modify, merge, publish,
distribute, sublicense, and/or sell copies of the Software, and to
permit persons to whom the Software is furnished to do so, subject to
the following conditions:

The above copyright notice and this permission notice shall be
included in all copies or substantial portions of the Software.

THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

aiorpcX-0.24/MANIFEST.in

include LICENCE
recursive-include aiorpcx *.py

aiorpcX-0.24/README.rst

.. image:: https://badge.fury.io/py/aiorpcX.svg
    :target: http://badge.fury.io/py/aiorpcX
.. image:: https://travis-ci.org/kyuupichan/aiorpcX.svg?branch=master
    :target: https://travis-ci.org/kyuupichan/aiorpcX
.. image:: https://coveralls.io/repos/github/kyuupichan/aiorpcX/badge.svg
    :target: https://coveralls.io/github/kyuupichan/aiorpcX

=======
aiorpcX
=======

A generic `asyncio <https://docs.python.org/3/library/asyncio.html>`_
library implementation of RPC suitable for an application that is a
client, server or both.

:Licence: MIT
:Language: Python (>= 3.9)
:Author: Neil Booth

Documentation
=============

See `readthedocs <https://aiorpcx.readthedocs.io/>`_.

aiorpcX-0.24/aiorpcx/__init__.py

from .curio import *
from .framing import *
from .jsonrpc import *
from .rawsocket import *
from .socks import *
from .session import *
from .unixsocket import *
from .util import *
from .websocket import *

_version_str = '0.24.0'
_version = tuple(int(part) for part in _version_str.split('.'))

__all__ = (curio.__all__ +
           framing.__all__ +
           jsonrpc.__all__ +
           rawsocket.__all__ +
           socks.__all__ +
           session.__all__ +
           unixsocket.__all__ +
           util.__all__ +
           websocket.__all__)

aiorpcX-0.24/aiorpcx/curio.py

# The code below is mostly my own but based on the interfaces of the
# curio library by David Beazley.
# I'm considering switching to using curio.  In the meantime this is
# an attempt to provide a similar clean, pure-async interface and move
# away from direct framework-specific dependencies.  As asyncio
# differs in its design it is not possible to provide identical
# semantics.
#
# The curio library is distributed under the following licence:
#
# Copyright (C) 2015-2017
# David Beazley (Dabeaz LLC)
# All rights reserved.
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
# met:
#
# * Redistributions of source code must retain the above copyright notice,
#   this list of conditions and the following disclaimer.
# * Redistributions in binary form must reproduce the above copyright notice,
#   this list of conditions and the following disclaimer in the documentation
#   and/or other materials provided with the distribution.
# * Neither the name of the David Beazley or Dabeaz LLC may be used to
#   endorse or promote products derived from this software without
#   specific prior written permission.
#
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.

from asyncio import (
    CancelledError, get_event_loop, Queue, Event, Lock, Semaphore, sleep, current_task
)
from collections import deque

from aiorpcx.util import instantiate_coroutine


__all__ = (
    'Queue', 'Event', 'Lock', 'Semaphore', 'sleep', 'CancelledError',
    'run_in_thread', 'spawn', 'spawn_sync', 'TaskGroup', 'NoRemainingTasksError',
    'TaskTimeout', 'TimeoutCancellationError', 'UncaughtTimeoutError',
    'timeout_after', 'timeout_at', 'ignore_after', 'ignore_at',
)


async def run_in_thread(func, *args):
    '''Run a function in a separate thread, and await its completion.'''
    return await get_event_loop().run_in_executor(None, func, *args)


async def spawn(coro, *args, loop=None, daemon=False):
    return spawn_sync(coro, *args, loop=loop, daemon=daemon)


def spawn_sync(coro, *args, loop=None, daemon=False):
    coro = instantiate_coroutine(coro, args)
    loop = loop or get_event_loop()
    task = loop.create_task(coro)
    task._daemon = daemon
    return task


def safe_exception(task):
    try:
        return task.exception()
    except CancelledError as e:
        return e


class NoRemainingTasksError(RuntimeError):
    pass


class TaskGroup:
    '''A class representing a group of executing tasks.  tasks is an
    optional set of existing tasks to put into the group.  New tasks
    can later be added using the spawn() method below.

    wait specifies the policy used by the join() method when waiting
    for tasks.  If wait is all, wait for all tasks to complete.  If
    wait is any, wait for any task to complete and then cancel tasks
    that are still running.  If wait is object, wait for the first
    task to return a non-None result and cancel tasks that are still
    running.  None means wait for no tasks and cancel all still
    running.
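
    For example, the following sketch (the fetch coroutines are
    placeholders) returns once the faster of two tasks completes,
    cancelling the slower one::

        async with TaskGroup(wait=any) as group:
            await group.spawn(fetch_from_mirror_1)
            await group.spawn(fetch_from_mirror_2)
        print(group.completed.result())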
    Completed tasks are normally dropped, but if retain is True then a
    reference is kept so that the `results` and `exceptions`
    properties can be examined.  To avoid runaway memory use, this
    should only be done for groups with a limited number of tasks.

    When join() is called, if any of the tasks in the group raises an
    exception or is cancelled then all tasks in the group, including
    daemon tasks, are cancelled.  If the join() operation itself is
    cancelled then all running tasks in the group are also cancelled.
    Once join() returns, all tasks have completed and new tasks may
    not be added.  Tasks can be added while join() is waiting.

    A TaskGroup is often used as a context manager, which calls the
    join() method on context-exit.  Each TaskGroup is an independent
    entity.  Task groups do not form a hierarchy or any kind of
    relationship to other previously-created task groups or tasks.
    Moreover, tasks created by the top-level spawn() function are not
    placed into any task group.  To create a task in a group, it
    should be created using TaskGroup.spawn() or explicitly added
    using TaskGroup.add_task().

    A task group has the following public attributes:

    completed: initially None, and set by join() to the first task in
        the group that finished.  Tasks removed from the group by
        calls to next_done() (and, if wait is object, tasks returning
        None) do not count.
    joined: True if the task group join() operation has completed.
    daemons: a set of all running daemonic tasks in the group.
    tasks: a set of all non-daemonic tasks in the group.
    '''

    def __init__(self, tasks=(), *, wait=all, retain=False):
        if wait not in (any, all, object, None):
            raise ValueError('invalid wait argument')
        # Tasks that have not yet finished
        self._pending = set()
        # All non-daemonic tasks tracked by the group
        self.tasks = set()
        # All running daemonic tasks in the group
        self.daemons = set()
        # Non-daemonic tasks that have completed
        self._done = deque()
        self._wait = wait
        self._retain = retain
        self.joined = False
        self._semaphore = Semaphore(0)
        self.completed = None
        for task in tasks:
            self._add_task(task)

    def _on_done(self, task):
        task._task_group = None
        if getattr(task, '_daemon', False):
            self.daemons.discard(task)
        else:
            if not self._retain:
                self.tasks.remove(task)
            self._pending.discard(task)
            self._done.append(task)
            self._semaphore.release()

    def _add_task(self, task):
        '''Add an already existing task to the task group.'''
        if hasattr(task, '_task_group'):
            raise RuntimeError('task is already part of a group')
        if self.joined:
            raise RuntimeError('task group terminated')
        task._task_group = self
        daemon = getattr(task, '_daemon', False)
        if not daemon:
            self.tasks.add(task)
        if task.done():
            self._on_done(task)
        elif daemon:
            self.daemons.add(task)
        else:
            self._pending.add(task)
            task.add_done_callback(self._on_done)

    @property
    def result(self):
        '''The result of the first completed task.  Should only be called
        after join() has returned.'''
        if not self.joined:
            raise RuntimeError('task group not yet terminated')
        if not self.completed:
            raise RuntimeError('no task successfully completed')
        return self.completed.result()

    @property
    def exception(self):
        '''The exception of the first completed task.  Should only be called
        after join() has returned.'''
        if not self.joined:
            raise RuntimeError('task group not yet terminated')
        return safe_exception(self.completed) if self.completed else None

    @property
    def results(self):
        '''A list of all results collected by join() in no particular order.

        If a task raised an exception or was cancelled then that
        exception will be raised.
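
        For example, with retain=True (square is a placeholder
        coroutine returning its argument squared)::

            async with TaskGroup(retain=True) as group:
                for n in (1, 2, 3):
                    await group.spawn(square, n)
            print(group.results)   # [1, 4, 9] in no particular order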
        '''
        if not self.joined:
            raise RuntimeError('task group not yet terminated')
        return [task.result() for task in self.tasks]

    @property
    def exceptions(self):
        '''A list of all exceptions collected by join() in no particular order.'''
        if not self.joined:
            raise RuntimeError('task group not yet terminated')
        return [safe_exception(task) for task in self.tasks]

    async def spawn(self, coro, *args, daemon=False):
        '''Create a new task and put it in the group.  Returns a Task
        instance.  Daemonic tasks are both ignored and cancelled by
        join().
        '''
        task = await spawn(coro, *args, daemon=daemon)
        self._add_task(task)
        return task

    async def add_task(self, task):
        '''Add an already existing task to the task group.'''
        self._add_task(task)

    async def next_done(self):
        '''Return the next completed task and remove it from the group.
        Return None if no more tasks remain.  A TaskGroup may also be
        used as an asynchronous iterator.
        '''
        if self._done or self._pending:
            await self._semaphore.acquire()
            if self._done:
                return self._done.popleft()
        return None

    async def next_result(self):
        '''Return the result of the next completed task and remove it from
        the group.  If the task failed with an exception, that
        exception is raised.  A RuntimeError exception is raised if no
        tasks remain.
        '''
        task = await self.next_done()
        if not task:
            raise NoRemainingTasksError('no tasks remain')
        return task.result()

    async def join(self):
        '''Wait for tasks in the group to terminate according to the wait
        policy for the group.
        '''
        try:
            # Wait for no-one; all tasks are cancelled
            if self._wait is None:
                return
            while True:
                task = await self.next_done()
                if task is None:
                    return
                # Set self.completed if not yet set; unless wait is object and
                # the task returned None
                if self.completed is None:
                    if not (self._wait is object and not safe_exception(task)
                            and task.result() is None):
                        self.completed = task
                if (safe_exception(task) or self._wait is any
                        or (self._wait is object and self.completed)):
                    return
        finally:
            # Cancel everything including daemons
            await self._cancel_tasks(self._pending.union(self.daemons))
            self.joined = True

    async def _cancel_tasks(self, tasks):
        '''Cancel the passed set of tasks.  Wait for them to complete.'''
        for task in tasks:
            task.cancel()

        if tasks:
            def pop_task(task):
                unfinished.remove(task)
                if not unfinished:
                    all_done.set()

            unfinished = set(tasks)
            all_done = Event()
            for task in tasks:
                task.add_done_callback(pop_task)
            await all_done.wait()

    async def cancel_remaining(self):
        '''Cancel all remaining non-daemonic tasks and wait for them to
        complete.  If any task blocks cancellation this routine will
        not return.
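
        A sketch of typical use, abandoning outstanding work on
        shutdown (serve_client is a placeholder coroutine)::

            group = TaskGroup()
            await group.spawn(serve_client)
            # ... later, when shutting down:
            await group.cancel_remaining()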
        '''
        await self._cancel_tasks(self._pending)

    def __aiter__(self):
        return self

    async def __anext__(self):
        task = await self.next_done()
        if task:
            return task
        raise StopAsyncIteration

    async def __aenter__(self):
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        if exc_type:
            await self.cancel_remaining()
        await self.join()


class TaskTimeout(Exception):

    def __init__(self, secs, *args):
        super().__init__(*args)
        self.secs = secs

    def __str__(self):
        return f'task timed out after {self.secs}s'


class TimeoutCancellationError(CancelledError):
    pass


class UncaughtTimeoutError(Exception):
    pass


def _set_new_deadline(task, deadline):
    def timeout_task():
        # Unfortunately task.cancel is all we can do with asyncio
        task.cancel()
        task._timed_out = deadline
    task._deadline_handle = task._loop.call_at(deadline, timeout_task)


def _set_task_deadline(task, deadline):
    deadlines = getattr(task, '_deadlines', [])
    if deadlines:
        if deadline < min(deadlines):
            task._deadline_handle.cancel()
            _set_new_deadline(task, deadline)
    else:
        _set_new_deadline(task, deadline)
    deadlines.append(deadline)
    task._deadlines = deadlines
    task._timed_out = None


def _unset_task_deadline(task):
    deadlines = task._deadlines
    timed_out_deadline = task._timed_out
    uncaught = timed_out_deadline not in deadlines
    task._deadline_handle.cancel()
    deadlines.pop()
    if deadlines:
        _set_new_deadline(task, min(deadlines))
    return timed_out_deadline, uncaught


class TimeoutAfter:

    def __init__(self, deadline, *, ignore=False, absolute=False):
        self._deadline = deadline
        self._ignore = ignore
        self._absolute = absolute
        self._secs = None
        self._task = None
        self.expired = False

    async def __aenter__(self):
        task = current_task()
        loop_time = task._loop.time()
        if self._absolute:
            self._secs = self._deadline - loop_time
        else:
            self._secs = self._deadline
            self._deadline += loop_time
        _set_task_deadline(task, self._deadline)
        self.expired = False
        self._task = task
        return self

    async def __aexit__(self, exc_type, exc_value, traceback):
        timed_out_deadline, uncaught = _unset_task_deadline(self._task)
        if exc_type not in (CancelledError, TaskTimeout,
                            TimeoutCancellationError):
            return False
        if timed_out_deadline == self._deadline:
            self.expired = True
            if self._ignore:
                return True
            raise TaskTimeout(self._secs) from None
        if timed_out_deadline is None:
            return False
        if uncaught:
            raise UncaughtTimeoutError('uncaught timeout received')
        if exc_type is TimeoutCancellationError:
            return False
        raise TimeoutCancellationError(timed_out_deadline) from None


async def _timeout_after_func(seconds, absolute, coro, args):
    coro = instantiate_coroutine(coro, args)
    async with TimeoutAfter(seconds, absolute=absolute):
        return await coro


def timeout_after(seconds, coro=None, *args):
    '''Execute the specified coroutine and return its result.  However,
    issue a cancellation request to the calling task after seconds
    have elapsed.  When this happens, a TaskTimeout exception is
    raised.  If coro is None, the result of this function serves as an
    asynchronous context manager that applies a timeout to a block of
    statements.

    timeout_after() may be composed with other timeout_after()
    operations (i.e., nested timeouts).  If an outer timeout expires
    first, then TimeoutCancellationError is raised instead of
    TaskTimeout.  If an inner timeout expires and fails to properly
    raise TaskTimeout, an UncaughtTimeoutError is raised in the outer
    timeout.
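
    For example, both forms below give fetch_data (a placeholder
    coroutine) five seconds to complete::

        result = await timeout_after(5, fetch_data, url)

        async with timeout_after(5):
            result = await fetch_data(url)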
    '''
    if coro:
        return _timeout_after_func(seconds, False, coro, args)

    return TimeoutAfter(seconds)


def timeout_at(clock, coro=None, *args):
    '''Execute the specified coroutine and return its result.  However,
    issue a cancellation request to the calling task once the event
    loop's clock reaches the absolute time clock.  When this happens,
    a TaskTimeout exception is raised.  If coro is None, the result of
    this function serves as an asynchronous context manager that
    applies a timeout to a block of statements.

    timeout_at() may be composed with timeout_after() and other
    timeout_at() operations (i.e., nested timeouts).  If an outer
    timeout expires first, then TimeoutCancellationError is raised
    instead of TaskTimeout.  If an inner timeout expires and fails to
    properly raise TaskTimeout, an UncaughtTimeoutError is raised in
    the outer timeout.
    '''
    if coro:
        return _timeout_after_func(clock, True, coro, args)

    return TimeoutAfter(clock, absolute=True)


async def _ignore_after_func(seconds, absolute, coro, args, timeout_result):
    coro = instantiate_coroutine(coro, args)
    async with TimeoutAfter(seconds, absolute=absolute, ignore=True):
        return await coro

    return timeout_result


def ignore_after(seconds, coro=None, *args, timeout_result=None):
    '''Execute the specified coroutine and return its result.  Issue a
    cancellation request after seconds have elapsed.  When a timeout
    occurs, no exception is raised.  Instead, timeout_result is
    returned.

    If coro is None, the result is an asynchronous context manager
    that applies a timeout to a block of statements.  For the context
    manager case, the resulting context manager object has an expired
    attribute set to True if time expired.

    Note: ignore_after() may also be composed with other timeout
    operations.  TimeoutCancellationError and UncaughtTimeoutError
    exceptions might be raised according to the same rules as for
    timeout_after().
    '''
    if coro:
        return _ignore_after_func(seconds, False, coro, args, timeout_result)

    return TimeoutAfter(seconds, ignore=True)


def ignore_at(clock, coro=None, *args, timeout_result=None):
    '''Stop the enclosed task or block of code at an absolute clock value.
    Same usage as ignore_after().
    '''
    if coro:
        return _ignore_after_func(clock, True, coro, args, timeout_result)

    return TimeoutAfter(clock, absolute=True, ignore=True)

aiorpcX-0.24/aiorpcx/framing.py

# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''RPC message framing in a byte stream.'''

__all__ = ('FramerBase', 'NewlineFramer', 'BinaryFramer', 'BitcoinFramer',
           'OversizedPayloadError', 'BadChecksumError', 'BadMagicError')

from hashlib import sha256 as _sha256
from struct import Struct

from .curio import Queue


class FramerBase:
    '''Abstract base class for a framer.

    A framer breaks an incoming byte stream into protocol messages,
    buffering if necessary.  It also frames outgoing messages into a
    byte stream.
    '''

    def frame(self, message):
        '''Return the framed message.'''
        raise NotImplementedError

    def received_bytes(self, data):
        '''Pass incoming network bytes.'''
        raise NotImplementedError

    async def receive_message(self):
        '''Wait for a complete unframed message to arrive, and return it.'''
        raise NotImplementedError

    def fail(self, exception):
        '''Raise exception to receive_message.'''
        raise NotImplementedError


class NewlineFramer(FramerBase):
    '''A framer for a protocol where messages are separated by newlines.'''

    # The default max_size value is motivated by JSONRPC, where a
    # normal request will be 250 bytes or less, and a reasonable
    # batch may contain 4000 requests.
    def __init__(self, max_size=250 * 4000):
        '''max_size - an anti-DoS measure.  If, after processing an incoming
        message, buffered data would exceed max_size bytes, that
        buffered data is dropped entirely and the framer waits for a
        newline character to re-synchronize the stream.  Set to zero
        to not limit the buffer size.
        '''
        self.max_size = max_size
        self.queue = Queue()
        self.received_bytes = self.queue.put_nowait
        self.synchronizing = False
        self.residual = b''
        self.exception = None

    def frame(self, message):
        return message + b'\n'

    def fail(self, exception):
        self.exception = exception
        self.received_bytes(b'')

    async def receive_message(self):
        parts = []
        buffer_size = 0
        while True:
            part = self.residual
            self.residual = b''
            if not part:
                part = await self.queue.get()
                if self.exception:
                    raise self.exception

            npos = part.find(b'\n')
            if npos == -1:
                parts.append(part)
                buffer_size += len(part)
                # Ignore over-sized messages; re-synchronize
                if buffer_size <= self.max_size or self.max_size == 0:
                    continue
                self.synchronizing = True
                raise MemoryError(f'dropping message over {self.max_size:,d} '
                                  f'bytes and re-synchronizing')

            tail, self.residual = part[:npos], part[npos + 1:]
            if self.synchronizing:
                self.synchronizing = False
                return await self.receive_message()
            else:
                parts.append(tail)
                return b''.join(parts)


class ByteQueue(object):
    '''A producer-consumer queue.

    Incoming network data is put as it arrives, and the consumer calls
    an async method waiting for data of a specific length.
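
    A sketch of its use (the byte values are illustrative)::

        queue = ByteQueue()
        queue.put_nowait(b'abcd')
        queue.put_nowait(b'efgh')
        data = await queue.receive(6)   # b'abcdef'; b'gh' stays buffered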
    '''

    def __init__(self):
        self.queue = Queue()
        self.parts = []
        self.parts_len = 0
        self.put_nowait = self.queue.put_nowait
        self.exception = None

    def fail(self, exception):
        self.exception = exception
        self.put_nowait(b'')

    async def receive(self, size):
        if self.exception:
            raise self.exception
        while self.parts_len < size:
            part = await self.queue.get()
            if self.exception:
                raise self.exception
            self.parts.append(part)
            self.parts_len += len(part)
        self.parts_len -= size
        whole = b''.join(self.parts)
        self.parts = [whole[size:]]
        return whole[:size]


class BinaryFramer(object):
    '''A framer for binary messaging protocols.'''

    def __init__(self):
        self.byte_queue = ByteQueue()
        self.message_queue = Queue()
        self.received_bytes = self.byte_queue.put_nowait
        self.fail = self.byte_queue.fail

    def frame(self, message):
        command, payload = message
        return b''.join((
            self._build_header(command, payload),
            payload
        ))

    async def receive_message(self):
        command, payload_len, checksum = await self._receive_header()
        payload = await self.byte_queue.receive(payload_len)
        payload_checksum = self._checksum(payload)
        if payload_checksum != checksum:
            raise BadChecksumError(payload_checksum, checksum)
        return command, payload

    def _checksum(self, payload):
        raise NotImplementedError

    def _build_header(self, command, payload):
        raise NotImplementedError

    async def _receive_header(self):
        raise NotImplementedError


# Helpers
struct_le_I = Struct('<I')
pack_le_uint32 = struct_le_I.pack


def sha256(x):
    '''Simple wrapper of hashlib sha256.'''
    return _sha256(x).digest()


def double_sha256(x):
    '''SHA-256 of SHA-256, as used extensively in bitcoin.'''
    return sha256(sha256(x))


class BadChecksumError(Exception):
    pass


class BadMagicError(Exception):
    pass


class OversizedPayloadError(Exception):
    pass


class BitcoinFramer(BinaryFramer):
    '''Provides a framer of binary message payloads in the style of the
    Bitcoin network protocol.

    Each message has a 24-byte header: a 4-byte magic, a 12-byte
    NUL-padded ASCII command, a 4-byte little-endian payload length,
    and a 4-byte checksum (the first four bytes of the double-SHA256
    of the payload).
    '''

    max_payload_size = 1024 * 1024

    def __init__(self, magic, max_block_size):
        def pad_command(command):
            fill = 12 - len(command)
            if fill < 0:
                raise ValueError(f'command {command} too long')
            return command + bytes(fill)

        super().__init__()
        self._magic = magic
        self._max_block_size = max_block_size
        self._pad_command = pad_command
        self._unpack = Struct('<4s12sI4s').unpack

    def _checksum(self, payload):
        return double_sha256(payload)[:4]

    def _build_header(self, command, payload):
        return b''.join((
            self._magic,
            self._pad_command(command),
            pack_le_uint32(len(payload)),
            self._checksum(payload)
        ))

    async def _receive_header(self):
        header = await self.byte_queue.receive(24)
        magic, command, payload_len, checksum = self._unpack(header)
        if magic != self._magic:
            raise BadMagicError(magic, self._magic)
        command = command.rstrip(b'\0')
        if payload_len > self.max_payload_size:
            if command != b'block' or payload_len > self._max_block_size:
                # Might be better to remove the payload
                raise OversizedPayloadError(command, payload_len)
        return command, payload_len, checksum
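

# A usage sketch (not part of the library); the magic is the Bitcoin mainnet
# value and the block size limit is illustrative:
#
#   framer = BitcoinFramer(bytes.fromhex('f9beb4d9'), 128 * 1024 * 1024)
#   wire_bytes = framer.frame((b'ping', b'\x00' * 8))   # header + payload
#   framer.received_bytes(wire_bytes)
#   command, payload = await framer.receive_message()   # (b'ping', b'\x00' * 8)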

aiorpcX-0.24/aiorpcx/jsonrpc.py

# Copyright (c) 2018-2019, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

'''Classes for JSONRPC versions 1.0 and 2.0, and a loose interpretation.'''

__all__ = ('JSONRPC', 'JSONRPCv1', 'JSONRPCv2', 'JSONRPCLoose',
           'JSONRPCAutoDetect', 'Request', 'Notification', 'Batch',
           'RPCError', 'ProtocolError', 'JSONRPCConnection',
           'handler_invocation')

import itertools
import json
from functools import partial
from numbers import Number
from asyncio import get_event_loop

from aiorpcx.util import signature_info


class SingleRequest:
    __slots__ = ('method', 'args')

    def __init__(self, method, args):
        if not isinstance(method, str):
            raise ProtocolError(JSONRPC.METHOD_NOT_FOUND,
                                'method must be a string')
        if not isinstance(args, (list, tuple, dict)):
            raise ProtocolError.invalid_args('request arguments must be a '
                                             'list or a dictionary')
        self.args = args
        self.method = method

    def __repr__(self):
        return f'{self.__class__.__name__}({self.method!r}, {self.args!r})'

    def __eq__(self, other):
        return (isinstance(other, self.__class__) and
                self.method == other.method and self.args == other.args)


class Request(SingleRequest):
    def send_result(self, _response):
        return None


class Notification(SingleRequest):
    pass


class Batch:
    __slots__ = ('items', )

    def __init__(self, items):
        if not isinstance(items, (list, tuple)):
            raise ProtocolError.invalid_request('items must be a list')
        if not items:
            raise ProtocolError.empty_batch()
        if not (all(isinstance(item, SingleRequest) for item in items) or
                all(isinstance(item, Response) for item in items)):
            raise ProtocolError.invalid_request('batch must be homogeneous')
        self.items = items

    def __len__(self):
        return len(self.items)

    def __getitem__(self, item):
        return self.items[item]

    def __iter__(self):
        return iter(self.items)

    def __repr__(self):
        return f'Batch({len(self.items)} items)'


class Response:
    __slots__ = ('result', )

    def __init__(self, result):
        # Type checking happens when converting to a message
        self.result = result


class CodeMessageError(Exception):
    '''Invoke as CodeMessageError(code, message)'''

    @property
    def code(self):
        return self.args[0]

    @property
    def message(self):
        return self.args[1]

    def __eq__(self, other):
        return (isinstance(other, self.__class__) and
                self.code == other.code and self.message == other.message)

    def __hash__(self):
        # overridden to make the exception hashable
        # see https://bugs.python.org/issue28603
        return hash((self.code, self.message))

    @classmethod
    def invalid_args(cls, message):
        return cls(JSONRPC.INVALID_ARGS, message)

    @classmethod
    def invalid_request(cls, message):
        return cls(JSONRPC.INVALID_REQUEST, message)

    @classmethod
    def empty_batch(cls):
        return cls.invalid_request('batch is empty')


class RPCError(CodeMessageError):

    def __init__(self, code, message, *, cost=0.0):
        super().__init__(code, message)
        self.cost = cost


class ProtocolError(CodeMessageError):

    def __init__(self, code, message):
        super().__init__(code, message)
        # If not None send this unframed message over the network
        self.error_message = None
        # If the error was in a JSON response message, its message ID.
        # Since None can be a response message ID, "id" means the
        # error was not sent in a JSON response
        self.response_msg_id = id


class JSONRPC:
    '''Abstract base class that interprets and constructs JSON RPC
    messages.'''

    # Error codes.
    # See http://www.jsonrpc.org/specification
    PARSE_ERROR = -32700
    INVALID_REQUEST = -32600
    METHOD_NOT_FOUND = -32601
    INVALID_ARGS = -32602
    INTERNAL_ERROR = -32603

    # Codes specific to this library
    ERROR_CODE_UNAVAILABLE = -100
    EXCESSIVE_RESOURCE_USAGE = -101
    SERVER_BUSY = -102

    # Can be overridden by derived classes
    allow_batches = True

    @classmethod
    def _message_id(cls, message, require_id):
        '''Validate the message is a dictionary and return its ID.

        Raise an error if the message is invalid or the ID is of an
        invalid type.  If it has no ID, raise an error if require_id
        is True, otherwise return None.
        '''
        raise NotImplementedError

    @classmethod
    def _validate_message(cls, message):
        '''Validate other parts of the message other than those done in
        _message_id.'''

    @classmethod
    def _request_args(cls, request):
        '''Validate the existence and type of the arguments passed in the
        request dictionary.'''
        raise NotImplementedError

    @classmethod
    def _process_request(cls, payload):
        request_id = None
        try:
            request_id = cls._message_id(payload, False)
            cls._validate_message(payload)
            method = payload.get('method')
            if request_id is None:
                item = Notification(method, cls._request_args(payload))
            else:
                item = Request(method, cls._request_args(payload))
            return item, request_id
        except ProtocolError as error:
            code, message = error.code, error.message
            raise cls._error(code, message, True, request_id)

    @classmethod
    def _process_response(cls, payload):
        request_id = None
        try:
            request_id = cls._message_id(payload, True)
            cls._validate_message(payload)
            return Response(cls.response_value(payload)), request_id
        except ProtocolError as error:
            code, message = error.code, error.message
            raise cls._error(code, message, False, request_id)

    @classmethod
    def _message_to_payload(cls, message):
        '''Returns a Python object or a ProtocolError.'''
        try:
            return json.loads(message.decode())
        except UnicodeDecodeError:
            message = 'messages must be encoded in UTF-8'
        except json.JSONDecodeError:
            message = 'invalid JSON'
        raise cls._error(cls.PARSE_ERROR, message, True, None)

    @classmethod
    def _error(cls, code, message, send, msg_id):
        error = ProtocolError(code, message)
        if send:
            error.error_message = cls.response_message(error, msg_id)
        else:
            error.response_msg_id = msg_id
        return error

    #
    # External API
    #

    @classmethod
    def message_to_item(cls, message):
        '''Translate an unframed received message and return an
        (item, request_id) pair.

        The item can be a Request, Notification, Response or a list.

        A JSON RPC error response is returned as an RPCError inside a
        Response object.

        If a Batch is returned, request_id is an iterable of request
        ids, one per batch member.

        If the message violates the protocol in some way a
        ProtocolError is raised, except if the message was determined
        to be a response, in which case the ProtocolError is placed
        inside a Response object.  This is so that client code can
        mark a request as having been responded to even if the
        response was bad.
        raises: ProtocolError
        '''
        payload = cls._message_to_payload(message)
        if isinstance(payload, dict):
            if 'method' in payload:
                return cls._process_request(payload)
            else:
                return cls._process_response(payload)
        elif isinstance(payload, list) and cls.allow_batches:
            if not payload:
                raise cls._error(JSONRPC.INVALID_REQUEST, 'batch is empty',
                                 True, None)
            return payload, None
        raise cls._error(cls.INVALID_REQUEST,
                         'request object must be a dictionary', True, None)

    # Message formation

    @classmethod
    def request_message(cls, item, request_id):
        '''Convert an RPCRequest item to a message.'''
        assert isinstance(item, Request)
        return cls.encode_payload(cls.request_payload(item, request_id))

    @classmethod
    def notification_message(cls, item):
        '''Convert an RPCRequest item to a message.'''
        assert isinstance(item, Notification)
        return cls.encode_payload(cls.request_payload(item, None))

    @classmethod
    def response_message(cls, result, request_id):
        '''Convert a response result (or RPCError) to a message.'''
        if isinstance(result, CodeMessageError):
            payload = cls.error_payload(result, request_id)
        else:
            payload = cls.response_payload(result, request_id)
        return cls.encode_payload(payload)

    @classmethod
    def batch_message(cls, batch, request_ids):
        '''Convert a request Batch to a message.'''
        assert isinstance(batch, Batch)
        if not cls.allow_batches:
            raise ProtocolError.invalid_request(
                'protocol does not permit batches')
        id_iter = iter(request_ids)
        rm = cls.request_message
        nm = cls.notification_message
        parts = (rm(request, next(id_iter)) if isinstance(request, Request)
                 else nm(request) for request in batch)
        return cls.batch_message_from_parts(parts)

    @classmethod
    def batch_message_from_parts(cls, messages):
        '''Convert messages, one per batch item, into a batch message.  At
        least one message must be passed.
        '''
        # Comma-separate the messages and wrap the lot in square brackets
        middle = b', '.join(messages)
        if not middle:
            raise ProtocolError.empty_batch()
        return b''.join([b'[', middle, b']'])

    @classmethod
    def encode_payload(cls, payload):
        '''Encode a Python object as JSON and convert it to bytes.'''
        try:
            return json.dumps(payload, separators=(',', ':')).encode()
        except TypeError:
            msg = f'JSON payload encoding error: {payload}'
            raise ProtocolError(cls.INTERNAL_ERROR, msg) from None


class JSONRPCv1(JSONRPC):
    '''JSON RPC version 1.0.'''

    allow_batches = False

    @classmethod
    def _message_id(cls, message, require_id):
        # JSONv1 requires an ID always, but without constraint on its type
        # No need to test for a dictionary here as we don't handle batches.
        if 'id' not in message:
            raise ProtocolError.invalid_request('request has no "id"')
        return message['id']

    @classmethod
    def _request_args(cls, request):
        args = request.get('params')
        if not isinstance(args, list):
            raise ProtocolError.invalid_args(
                f'invalid request arguments: {args}')
        return args

    @classmethod
    def _best_effort_error(cls, error):
        # Do our best to interpret the error
        code = cls.ERROR_CODE_UNAVAILABLE
        message = 'no error message provided'
        if isinstance(error, str):
            message = error
        elif isinstance(error, int):
            code = error
        elif isinstance(error, dict):
            if isinstance(error.get('message'), str):
                message = error['message']
            if isinstance(error.get('code'), int):
                code = error['code']
        return RPCError(code, message)

    @classmethod
    def response_value(cls, payload):
        if 'result' not in payload or 'error' not in payload:
            raise ProtocolError.invalid_request(
                'response must contain both "result" and "error"')

        result = payload['result']
        error = payload['error']
        if error is None:
            return result   # It seems None can be a valid result
        if result is not None:
            raise ProtocolError.invalid_request(
                'response has a "result" and an "error"')

        return cls._best_effort_error(error)

    @classmethod
    def request_payload(cls, request, request_id):
        '''JSON v1 request (or notification) payload.'''
        if isinstance(request.args, dict):
            raise ProtocolError.invalid_args(
                'JSONRPCv1 does not support named arguments')
        return {
            'method': request.method,
            'params': request.args,
            'id': request_id
        }

    @classmethod
    def response_payload(cls, result, request_id):
        '''JSON v1 response payload.'''
        return {
            'result': result,
            'error': None,
            'id': request_id
        }

    @classmethod
    def error_payload(cls, error, request_id):
        return {
            'result': None,
            'error': {'code': error.code, 'message': error.message},
            'id': request_id
        }


class JSONRPCv2(JSONRPC):
    '''JSON RPC version 2.0.'''

    @classmethod
    def _message_id(cls, message, require_id):
        if not isinstance(message, dict):
            raise ProtocolError.invalid_request(
                'request object must be a dictionary')
        if 'id' in message:
            request_id = message['id']
            if not isinstance(request_id, (Number, str, type(None))):
                raise ProtocolError.invalid_request(
                    f'invalid "id": {request_id}')
            return request_id
        else:
            if require_id:
                raise ProtocolError.invalid_request('request has no "id"')
            return None

    @classmethod
    def _validate_message(cls, message):
        if message.get('jsonrpc') != '2.0':
            raise ProtocolError.invalid_request('"jsonrpc" is not "2.0"')

    @classmethod
    def _request_args(cls, request):
        args = request.get('params', [])
        if not isinstance(args, (dict, list)):
            raise ProtocolError.invalid_args(
                f'invalid request arguments: {args}')
        return args

    @classmethod
    def response_value(cls, payload):
        if 'result' in payload:
            if 'error' in payload:
                raise ProtocolError.invalid_request(
                    'response contains both "result" and "error"')
            return payload['result']

        if 'error' not in payload:
            raise ProtocolError.invalid_request(
                'response contains neither "result" nor "error"')

        # Return an RPCError object
        error = payload['error']
        if isinstance(error, dict):
            code = error.get('code')
            message = error.get('message')
            if isinstance(code, int) and isinstance(message, str):
                return RPCError(code, message)

        raise ProtocolError.invalid_request(
            f'ill-formed response error object: {error}')

    @classmethod
    def request_payload(cls, request, request_id):
        '''JSON v2 request (or notification) payload.'''
        payload = {
            'jsonrpc': '2.0',
            'method': request.method,
        }
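        # For example (illustrative values), Request('echo', [1]) with
        # request_id 6 encodes, modulo whitespace, as:
        #   {"jsonrpc": "2.0", "method": "echo", "id": 6, "params": [1]}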
        # A notification?
        if request_id is not None:
            payload['id'] = request_id
        # Preserve empty dicts as missing params is read as an array
        if request.args or request.args == {}:
            payload['params'] = request.args
        return payload

    @classmethod
    def response_payload(cls, result, request_id):
        '''JSON v2 response payload.'''
        return {
            'jsonrpc': '2.0',
            'result': result,
            'id': request_id
        }

    @classmethod
    def error_payload(cls, error, request_id):
        return {
            'jsonrpc': '2.0',
            'error': {'code': error.code, 'message': error.message},
            'id': request_id
        }


class JSONRPCLoose(JSONRPC):
    '''A relaxed version of JSON RPC.'''

    # Don't be so loose we accept any old message ID
    _message_id = JSONRPCv2._message_id
    _validate_message = JSONRPC._validate_message
    _request_args = JSONRPCv2._request_args
    # Outgoing messages are JSONRPCv2 so we give the other side the
    # best chance to assume / detect JSONRPCv2 as default protocol.
    error_payload = JSONRPCv2.error_payload
    request_payload = JSONRPCv2.request_payload
    response_payload = JSONRPCv2.response_payload

    @classmethod
    def response_value(cls, payload):
        # Return result, unless it is None and there is an error
        if payload.get('error') is not None:
            if payload.get('result') is not None:
                raise ProtocolError.invalid_request(
                    'response contains both "result" and "error"')
            return JSONRPCv1._best_effort_error(payload['error'])

        if 'result' not in payload:
            raise ProtocolError.invalid_request(
                'response contains neither "result" nor "error"')

        # Can be None
        return payload['result']


class JSONRPCAutoDetect(JSONRPCv2):

    @classmethod
    def detect_protocol(cls, message):
        '''Attempt to detect the protocol from the message.'''
        main = cls._message_to_payload(message)

        def protocol_for_payload(payload):
            if not isinstance(payload, dict):
                return JSONRPCLoose   # Will error
            # Obey an explicit "jsonrpc"
            version = payload.get('jsonrpc')
            if version == '2.0':
                return JSONRPCv2
            if version == '1.0':
                return JSONRPCv1

            # Now to decide between JSONRPCLoose and JSONRPCv1 if possible
            if 'result' in payload and 'error' in payload:
                return JSONRPCv1
            return JSONRPCLoose

        if isinstance(main, list):
            parts = set(protocol_for_payload(payload) for payload in main)
            # If all same protocol, return it
            if len(parts) == 1:
                return parts.pop()
            # If strict protocol detected, return it, preferring JSONRPCv2.
            # This means a batch of JSONRPCv1 will fail
            for protocol in (JSONRPCv2, JSONRPCv1):
                if protocol in parts:
                    return protocol
            # Will error if no parts
            return JSONRPCLoose

        return protocol_for_payload(main)


class JSONRPCConnection:
    '''Maintains state of a JSON RPC connection, in particular
    encapsulating the handling of request IDs.

    protocol - the JSON RPC protocol to follow
    max_response_size - responses over this size send an error
        response instead.
    '''

    def __init__(self, protocol):
        self._protocol = protocol
        self._id_counter = itertools.count()
        # Sent Requests and Batches that have not received a response.
        # The key is its request ID; for a batch it is a sorted tuple
        # of request IDs
        self._requests = {}
        self._create_future = get_event_loop().create_future

        # A public attribute intended to be settable dynamically
        self.max_response_size = 0

    def _oversized_response_message(self, request_id):
        text = f'response too large (over {self.max_response_size:,d} bytes)'
        error = RPCError.invalid_request(text)
        return self._protocol.response_message(error, request_id)

    def _receive_response(self, result, request_id):
        if request_id not in self._requests:
            if request_id is None and isinstance(result, RPCError):
                message = f'diagnostic error received: {result}'
            else:
                message = f'response to unsent request (ID: {request_id})'
            raise ProtocolError.invalid_request(message) from None
        _request, future = self._requests.pop(request_id)
        if not future.done():
            if isinstance(result, Exception):
                future.set_exception(result)
            else:
                future.set_result(result)
        return []

    def _receive_request_batch(self, payloads):
        def item_send_result(request_id, result):
            nonlocal size
            part = protocol.response_message(result, request_id)
            size += len(part) + 2
            if size > self.max_response_size > 0:
                part = self._oversized_response_message(request_id)
            parts.append(part)
            if len(parts) == count:
                return protocol.batch_message_from_parts(parts)
            return None

        parts = []
        items = []
        size = 0
        count = 0
        protocol = self._protocol
        for payload in payloads:
            try:
                item, request_id = protocol._process_request(payload)
                items.append(item)
                if isinstance(item, Request):
                    count += 1
                    item.send_result = partial(item_send_result, request_id)
            except ProtocolError as error:
                count += 1
                parts.append(error.error_message)

        if not items and parts:
            error = ProtocolError(0, "")
            error.error_message = protocol.batch_message_from_parts(parts)
            raise error
        return items

    def _receive_response_batch(self, payloads):
        request_ids = []
        results = []
        for payload in payloads:
            # Let ProtocolError exceptions through
            item, request_id = self._protocol._process_response(payload)
            request_ids.append(request_id)
            results.append(item.result)

        ordered = sorted(zip(request_ids, results), key=lambda t: t[0])
        ordered_ids, ordered_results = zip(*ordered)
        if ordered_ids not in self._requests:
            raise ProtocolError.invalid_request('response to unsent batch')
        _request_batch, future = self._requests.pop(ordered_ids)
        if not future.done():
            future.set_result(ordered_results)
        return []

    def _send_result(self, request_id, result):
        message = self._protocol.response_message(result, request_id)
        if len(message) > self.max_response_size > 0:
            message = self._oversized_response_message(request_id)
        return message

    def _future(self, request, request_id):
        future = self._create_future()
        self._requests[request_id] = (request, future)
        return future

    #
    # External API
    #

    def send_request(self, request):
        '''Send a Request.  Return a (message, future) pair.

        The message is an unframed message to send over the network.
        Await the future for the response.

        Raises: ProtocolError if the request violates the protocol in
        some way.
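
        A sketch of direct use (most applications drive this via a
        session class instead)::

            message, future = connection.send_request(Request('ping', []))
            # ... send 'message' over the network, and pass replies
            # to receive_message(); then:
            result = await future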
        '''
        request_id = next(self._id_counter)
        message = self._protocol.request_message(request, request_id)
        return message, self._future(request, request_id)

    def send_notification(self, notification):
        return self._protocol.notification_message(notification)

    def send_batch(self, batch):
        ids = tuple(next(self._id_counter)
                    for request in batch if isinstance(request, Request))
        message = self._protocol.batch_message(batch, ids)
        event = self._future(batch, ids) if ids else None
        return message, event

    def receive_message(self, message):
        '''Call with an unframed message received from the network.

        Raises: ProtocolError if the message violates the protocol in
        some way.  However, if it happened in a response that can be
        paired with a request, the ProtocolError is instead set in the
        result attribute of the send_request() that caused the error.
        '''
        if self._protocol is JSONRPCAutoDetect:
            self._protocol = JSONRPCAutoDetect.detect_protocol(message)

        try:
            item, request_id = self._protocol.message_to_item(message)
        except ProtocolError as e:
            if e.response_msg_id is not id:
                return self._receive_response(e, e.response_msg_id)
            raise

        if isinstance(item, Request):
            item.send_result = partial(self._send_result, request_id)
            return [item]
        if isinstance(item, Notification):
            return [item]
        if isinstance(item, Response):
            return self._receive_response(item.result, request_id)

        assert isinstance(item, list)
        if all(isinstance(payload, dict) and
               ('result' in payload or 'error' in payload)
               for payload in item):
            return self._receive_response_batch(item)
        else:
            return self._receive_request_batch(item)

    def cancel_pending_requests(self):
        '''Cancel all pending requests.'''
        for _request, future in self._requests.values():
            if not future.done():
                future.cancel()
        self._requests.clear()

    def pending_requests(self):
        '''All sent requests that have not received a response.'''
        return [request for request, event in self._requests.values()]


def handler_invocation(handler, request):
    method, args = request.method, request.args
    if handler is None:
        raise RPCError(JSONRPC.METHOD_NOT_FOUND,
                       f'unknown method "{method}"')

    # We must test for too few and too many arguments.  How
    # depends on whether the arguments were passed as a list or as
    # a dictionary.
    info = signature_info(handler)
    if isinstance(args, (tuple, list)):
        if len(args) < info.min_args:
            s = '' if len(args) == 1 else 's'
            raise RPCError.invalid_args(
                f'{len(args)} argument{s} passed to method '
                f'"{method}" but it requires {info.min_args}')
        if info.max_args is not None and len(args) > info.max_args:
            s = '' if len(args) == 1 else 's'
            raise RPCError.invalid_args(
                f'{len(args)} argument{s} passed to method '
                f'"{method}" taking at most {info.max_args}')
        return partial(handler, *args)

    # Arguments passed by name
    if info.other_names is None:
        raise RPCError.invalid_args(f'method "{method}" cannot '
                                    f'be called with named arguments')

    missing = set(info.required_names).difference(args)
    if missing:
        s = '' if len(missing) == 1 else 's'
        missing = ', '.join(sorted(f'"{name}"' for name in missing))
        raise RPCError.invalid_args(f'method "{method}" requires '
                                    f'parameter{s} {missing}')

    if info.other_names is not any:
        excess = set(args).difference(info.required_names)
        excess = excess.difference(info.other_names)
        if excess:
            s = '' if len(excess) == 1 else 's'
            excess = ', '.join(sorted(f'"{name}"' for name in excess))
            raise RPCError.invalid_args(f'method "{method}" does not '
                                        f'take parameter{s} {excess}')
    return partial(handler, **args)

aiorpcX-0.24/aiorpcx/rawsocket.py

# Copyright (c) 2019, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.
'''Asyncio protocol abstraction.'''

__all__ = ('connect_rs', 'serve_rs')

import asyncio
from functools import partial

from aiorpcx.curio import Event, timeout_after, TaskTimeout
from aiorpcx.session import RPCSession, SessionBase, SessionKind
from aiorpcx.util import NetAddress


class ConnectionLostError(Exception):
    pass


class RSTransport(asyncio.Protocol):

    def __init__(self, session_factory, framer, kind):
        self.session_factory = session_factory
        self.loop = asyncio.get_event_loop()
        self.session = None
        self.kind = kind
        self._proxy = None
        self._asyncio_transport = None
        self._remote_address = None
        self._framer = framer
        # Cleared when the send socket is full
        self._can_send = Event()
        self._can_send.set()
        self._closed_event = Event()
        self._process_messages_task = None

    async def process_messages(self):
        try:
            await self.session.process_messages(self.receive_message)
        except ConnectionLostError:
            pass
        finally:
            self._closed_event.set()

    async def receive_message(self):
        return await self._framer.receive_message()

    def connection_made(self, transport):
        '''Called by asyncio when a connection is established.'''
        self._asyncio_transport = transport
        # If the SOCKS proxy was used then _proxy and _remote_address are already set
        if self._proxy is None:
            # This would throw if called on a closed SSL transport.  Fixed in asyncio in
            # Python 3.6.1 and 3.5.4
            peername = transport.get_extra_info('peername')
            if peername:
                self._remote_address = NetAddress(peername[0], peername[1])
        self.session = self.session_factory(self)
        self._framer = self._framer or self.session.default_framer()
        self._process_messages_task = self.loop.create_task(self.process_messages())

    def connection_lost(self, exc):
        '''Called by asyncio when the connection closes.

        Tear down things done in connection_made.'''
        # Release waiting tasks
        self._can_send.set()
        self._framer.fail(ConnectionLostError())

    def data_received(self, data):
        '''Called by asyncio when a message comes in.'''
        self.session.data_received(data)
        self._framer.received_bytes(data)

    def pause_writing(self):
        '''Called by asyncio when the send buffer is full.'''
        if not self.is_closing():
            self._can_send.clear()
            self._asyncio_transport.pause_reading()

    def resume_writing(self):
        '''Called by asyncio when the send buffer has room.'''
        if not self._can_send.is_set():
            self._can_send.set()
            self._asyncio_transport.resume_reading()

    # API exposed to session
    async def write(self, message):
        await self._can_send.wait()
        if not self.is_closing():
            framed_message = self._framer.frame(message)
            self._asyncio_transport.write(framed_message)

    async def close(self, force_after):
        '''Close the connection and return when closed.'''
        if self._asyncio_transport:
            self._asyncio_transport.close()
            try:
                async with timeout_after(force_after):
                    await self._closed_event.wait()
            except TaskTimeout:
                await self.abort()
                await self._closed_event.wait()

    async def abort(self):
        if self._asyncio_transport:
            self._asyncio_transport.abort()

    def is_closing(self):
        '''Return True if the connection is closing.'''
        return self._closed_event.is_set() or self._asyncio_transport.is_closing()

    def proxy(self):
        return self._proxy

    def remote_address(self):
        return self._remote_address


class RSClient:

    def __init__(self, host=None, port=None, proxy=None, *, framer=None, **kwargs):
        session_factory = kwargs.pop('session_factory', RPCSession)
        self.protocol_factory = partial(RSTransport, session_factory, framer,
                                        SessionKind.CLIENT)
        self.host = host
        self.port = port
        self.proxy = proxy
        self.session = None
        self.loop = kwargs.get('loop', asyncio.get_event_loop())
        self.kwargs = kwargs
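
    # A usage sketch (host and port are illustrative); RSClient is also
    # exported as connect_rs:
    #
    #   async with connect_rs('localhost', 8888) as session:
    #       result = await session.send_request('ping')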
    async def create_connection(self):
        '''Initiate a connection.'''
        connector = self.proxy or self.loop
        return await connector.create_connection(
            self.protocol_factory, self.host, self.port, **self.kwargs)

    async def __aenter__(self):
        _transport, protocol = await self.create_connection()
        self.session = protocol.session
        assert isinstance(self.session, SessionBase)
        return self.session

    async def __aexit__(self, exc_type, exc_value, traceback):
        await self.session.close()


async def serve_rs(session_factory, host=None, port=None, *, framer=None,
                   loop=None, **kwargs):
    loop = loop or asyncio.get_event_loop()
    protocol_factory = partial(RSTransport, session_factory, framer,
                               SessionKind.SERVER)
    return await loop.create_server(protocol_factory, host, port, **kwargs)


connect_rs = RSClient

aiorpcX-0.24/aiorpcx/session.py

# Copyright (c) 2018-2019, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE
# LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION
# OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION
# WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE.

__all__ = ('RPCSession', 'MessageSession', 'ExcessiveSessionCostError',
           'BatchError', 'Concurrency', 'ReplyAndDisconnect', 'SessionKind')

import asyncio
from enum import Enum
import logging
from math import ceil
import time

from aiorpcx.curio import (
    TaskGroup, TaskTimeout, timeout_after, sleep
)
from aiorpcx.framing import (
    NewlineFramer, BitcoinFramer,
    BadMagicError, BadChecksumError, OversizedPayloadError
)
from aiorpcx.jsonrpc import (
    Request, Batch, Notification, ProtocolError, RPCError, JSONRPC,
    JSONRPCv2, JSONRPCConnection
)


class ReplyAndDisconnect(Exception):
    '''Force a session disconnect after sending result (a Python object
    or an RPCError).
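
    For example, from within a request handler (the error code and
    message are illustrative)::

        raise ReplyAndDisconnect(RPCError(1, 'server shutting down'))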
    '''


class ExcessiveSessionCostError(RuntimeError):
    pass


class Concurrency:

    def __init__(self, target):
        self._target = int(target)
        self._semaphore = asyncio.Semaphore(self._target)
        self._sem_value = self._target

    async def _retarget_semaphore(self):
        if self._target <= 0:
            raise ExcessiveSessionCostError
        while self._sem_value < self._target:
            self._sem_value += 1
            self._semaphore.release()

    @property
    def max_concurrent(self):
        return self._target

    def set_target(self, target):
        self._target = int(target)

    async def __aenter__(self):
        await self._semaphore.acquire()
        await self._retarget_semaphore()

    async def __aexit__(self, exc_type, exc_value, traceback):
        if self._sem_value > self._target:
            self._sem_value -= 1
        else:
            self._semaphore.release()


class SessionKind(Enum):
    CLIENT = 'client'
    SERVER = 'server'


class SessionBase:
    '''Base class of networking sessions.

    There is no client / server distinction other than who initiated the connection.
    '''

    # Multiply this by bandwidth bytes used to get resource usage cost
    bw_cost_per_byte = 1 / 100000
    # If cost is over this requests begin to get delayed and concurrency is reduced
    cost_soft_limit = 2000
    # If cost is over this the session is closed
    cost_hard_limit = 10000
    # Resource usage is reduced by this every second
    cost_decay_per_sec = cost_hard_limit / 3600
    # Request delay ranges from 0 to this between cost_soft_limit and cost_hard_limit
    cost_sleep = 2.0
    # Base cost of an error.  Errors that took resources to discover incur additional costs
    error_base_cost = 100.0
    # Initial number of requests that can be concurrently processed
    initial_concurrent = 20
    # Send a "server busy" error if processing a request takes longer than this many seconds
    processing_timeout = 30.0
    # Force-close a connection if its socket send buffer stays full this long
    max_send_delay = 20.0

    def __init__(self, transport, *, loop=None):
        self.transport = transport
        self.loop = loop or asyncio.get_event_loop()
        self.logger = logging.getLogger(self.__class__.__name__)
        # For logger.debug messages
        self.verbosity = 0
        self._group = TaskGroup()
        # Statistics.  The RPC object also keeps its own statistics.
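        # The counters below track message counts and cumulative bytes in each
        # direction; the byte totals also feed the session's resource usage cost
        # via bw_cost_per_byte.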
self.start_time = time.time() self.errors = 0 self.send_count = 0 self.send_size = 0 self.last_send = self.start_time self.recv_count = 0 self.recv_size = 0 self.last_recv = self.start_time # Resource usage self.cost = 0.0 self._cost_last = 0.0 self._cost_time = self.start_time self._cost_fraction = 0.0 # Concurrency control for incoming request handling self._incoming_concurrency = Concurrency(self.initial_concurrent) # By default, do not limit outgoing connections if self.session_kind == SessionKind.CLIENT: self.cost_hard_limit = 0 async def _send_message(self, message): if self.verbosity >= 4: self.logger.debug(f'sending message {message}') try: async with timeout_after(self.max_send_delay): await self.transport.write(message) except TaskTimeout: await self.abort() raise self.send_size += len(message) self.bump_cost(len(message) * self.bw_cost_per_byte) self.send_count += 1 self.last_send = time.time() return self.last_send def _bump_errors(self, exception=None): self.errors += 1 self.bump_cost(self.error_base_cost + getattr(exception, 'cost', 0.0)) @property def session_kind(self): '''Either client or server.''' return self.transport.kind async def connection_lost(self): pass def data_received(self, data): if self.verbosity >= 2: self.logger.debug(f'received data {data}') self.recv_size += len(data) self.bump_cost(len(data) * self.bw_cost_per_byte) def bump_cost(self, delta): # Delta can be positive or negative self.cost = max(0, self.cost + delta) if abs(self.cost - self._cost_last) > 100: self.recalc_concurrency() def on_disconnect_due_to_excessive_session_cost(self): '''Called just before disconnecting from the session, if it was consuming too much resources. ''' def recalc_concurrency(self): '''Call to recalculate sleeps and concurrency for the session. Called automatically if cost has drifted significantly. Otherwise can be called at regular intervals if desired. ''' # Refund resource usage proportionally to elapsed time; the bump passed is negative now = time.time() self.cost = max(0, self.cost - (now - self._cost_time) * self.cost_decay_per_sec) self._cost_time = now self._cost_last = self.cost # Setting cost_hard_limit <= 0 means to not limit concurrency value = self._incoming_concurrency.max_concurrent cost_soft_range = self.cost_hard_limit - self.cost_soft_limit if cost_soft_range <= 0: return cost = self.cost + self.extra_cost() self._cost_fraction = max(0.0, (cost - self.cost_soft_limit) / cost_soft_range) target = max(0, ceil((1.0 - self._cost_fraction) * self.initial_concurrent)) if abs(target - value) > 1: self.logger.info(f'changing task concurrency from {value} to {target}') self._incoming_concurrency.set_target(target) async def _process_messages(self, recv_message): try: await self._process_messages_loop(recv_message) finally: # Call the hook provided for derived classes await self.connection_lost() async def process_messages(self, recv_message): async with self._group as group: await group.spawn(self._process_messages, recv_message) # Remove tasks async for task in group: task.result() def unanswered_request_count(self): '''The number of requests received but not yet answered.''' # Max with zero in case the message processing task hasn't yet spawned return max(0, len(self._group._pending) - 1) def extra_cost(self): '''A dynamic value added to this session's cost when deciding how much to throttle requests. Can be negative. 
        '''
        return 0.0

    def default_framer(self):
        '''Return a default framer.'''
        raise NotImplementedError

    def proxy(self):
        '''Returns the proxy used, or None.'''
        return self.transport.proxy()

    def remote_address(self):
        '''Returns a NetAddress or None if not connected.'''
        return self.transport.remote_address()

    def is_closing(self):
        '''Return True if the connection is closing.'''
        return self.transport.is_closing()

    async def abort(self):
        '''Forcefully close the connection.'''
        await self.transport.abort()

    async def close(self, *, force_after=30):
        '''Close the connection and return when closed.'''
        await self.transport.close(force_after)


class MessageSession(SessionBase):
    '''Session class for protocols where messages are not tied to responses, such as
    the Bitcoin protocol.
    '''

    async def _process_messages_loop(self, recv_message):
        while True:
            try:
                message = await recv_message()
            except BadMagicError as e:
                magic, expected = e.args
                self.logger.error(
                    f'bad network magic: got {magic} expected {expected}, '
                    f'disconnecting'
                )
                self._bump_errors(e)
                await self._group.spawn(self.close)
                await sleep(0.001)
            except OversizedPayloadError as e:
                command, payload_len = e.args
                self.logger.error(
                    f'oversized payload of {payload_len:,d} bytes to command '
                    f'{command}, disconnecting'
                )
                self._bump_errors(e)
                await self._group.spawn(self.close)
                await sleep(0.001)
            except BadChecksumError as e:
                payload_checksum, claimed_checksum = e.args
                self.logger.warning(
                    f'checksum mismatch: actual {payload_checksum.hex()} '
                    f'vs claimed {claimed_checksum.hex()}'
                )
                self._bump_errors(e)
            else:
                self.last_recv = time.time()
                self.recv_count += 1
                await self._group.spawn(self._throttled_message(message))

    async def _throttled_message(self, message):
        '''Process a single request, respecting the concurrency limit.'''
        try:
            timeout = self.processing_timeout
            async with timeout_after(timeout):
                async with self._incoming_concurrency:
                    if self._cost_fraction:
                        await sleep(self._cost_fraction * self.cost_sleep)
                    await self.handle_message(message)
        except ProtocolError as e:
            self.logger.error(f'{e}')
            self._bump_errors(e)
        except TaskTimeout:
            self.logger.info(f'incoming request timed out after {timeout} secs')
            self._bump_errors()
        except ExcessiveSessionCostError:
            self.on_disconnect_due_to_excessive_session_cost()
            await self.close()
        except Exception:
            self.logger.exception(f'exception handling {message}')
            self._bump_errors()

    def default_framer(self):
        '''Return a bitcoin framer.'''
        return BitcoinFramer()

    async def handle_message(self, message):
        '''message is a (command, payload) pair.'''

    async def send_message(self, message):
        '''Send a message (command, payload) over the network.'''
        await self._send_message(message)


class BatchError(Exception):

    def __init__(self, request):
        super().__init__(request)
        self.request = request   # BatchRequest object


class BatchRequest:
    '''Used to build a batch request to send to the server.  The attributes "batch" and
    "results" are initially None.

    Adding an invalid request or notification immediately raises a ProtocolError.

    On exiting the with clause, it will:

    1) create a Batch object for the requests in the order they were added.  If the
       batch is empty this raises a ProtocolError.
    2) set the "batch" attribute to be that batch
    3) send the batch request and wait for a response
    4) raise a ProtocolError if the protocol was violated by the server.  Currently
       this only happens if it gave more than one response to any request
    5) otherwise there is precisely one response to each Request.
Set the "results" attribute to the tuple of results; the responses are ordered to match the Requests in the batch. Notifications do not get a response. 6) if raise_errors is True and any individual response was a JSON RPC error response, or violated the protocol in some way, a BatchError exception is raised. Otherwise the caller can be certain each request returned a standard result. ''' def __init__(self, session, raise_errors): self._session = session self._raise_errors = raise_errors self._requests = [] self.batch = None self.results = None def add_request(self, method, args=()): self._requests.append(Request(method, args)) def add_notification(self, method, args=()): self._requests.append(Notification(method, args)) def __len__(self): return len(self._requests) async def __aenter__(self): return self async def __aexit__(self, exc_type, exc_value, traceback): if exc_type is None: self.batch = Batch(self._requests) message, future = self._session.connection.send_batch(self.batch) self.results = await self._session._send_concurrent(message, future, len(self.batch)) if self._raise_errors: if any(isinstance(item, Exception) for item in self.results): raise BatchError(self) class RPCSession(SessionBase): '''Base class for protocols where a message can lead to a response, for example JSON RPC.''' # Adjust outgoing request concurrency to target a round trip response time of # this many seconds, recalibrating every recalibrate_count requests target_response_time = 3.0 recalibrate_count = 30 # Raise a TaskTimeout if getting a response takes longer than this sent_request_timeout = 30.0 log_me = False def __init__(self, transport, *, loop=None, connection=None): super().__init__(transport, loop=loop) self.connection = connection or self.default_connection() # Concurrency control for outgoing request sending self._outgoing_concurrency = Concurrency(50) self._req_times = [] def _recalc_concurrency(self): req_times = self._req_times avg = sum(req_times) / len(req_times) req_times.clear() current = self._outgoing_concurrency.max_concurrent cap = min(current + max(3, current * 0.1), 250) floor = max(1, min(current * 0.8, current - 1)) if avg != 0: target = max(floor, min(cap, current * self.target_response_time / avg)) else: target = cap target = int(0.5 + target) if target != current: self.logger.info(f'changing outgoing request concurrency to {target} from {current}') self._outgoing_concurrency.set_target(target) async def _process_messages_loop(self, recv_message): # The loop will exit when recv_message raises a ConnectionLost error; which is also # arranged when close() is called. 
while True: try: message = await recv_message() except MemoryError as e: self.logger.warning(f'{e!r}') continue self.last_recv = time.time() self.recv_count += 1 if self.log_me: self.logger.info(f'processing {message}') try: requests = self.connection.receive_message(message) except ProtocolError as e: self.logger.debug(str(e)) if e.code == JSONRPC.PARSE_ERROR: e.cost = self.error_base_cost * 10 self._bump_errors(e) if e.error_message: await self._send_message(e.error_message) else: for request in requests: await self._group.spawn(self._throttled_request(request)) async def _throttled_request(self, request): '''Process a single request, respecting the concurrency limit.''' disconnect = False try: timeout = self.processing_timeout async with timeout_after(timeout): async with self._incoming_concurrency: if self._cost_fraction: await sleep(self._cost_fraction * self.cost_sleep) result = await self.handle_request(request) except (ProtocolError, RPCError) as e: result = e except TaskTimeout: self.logger.info(f'incoming request {request} timed out after {timeout} secs') result = RPCError(JSONRPC.SERVER_BUSY, 'server busy - request timed out') except ReplyAndDisconnect as e: result = e.args[0] disconnect = True except ExcessiveSessionCostError: self.on_disconnect_due_to_excessive_session_cost() result = RPCError(JSONRPC.EXCESSIVE_RESOURCE_USAGE, 'excessive resource usage') disconnect = True except Exception: self.logger.exception(f'exception handling {request}') result = RPCError(JSONRPC.INTERNAL_ERROR, 'internal server error') if isinstance(request, Request): message = request.send_result(result) if message: await self._send_message(message) if isinstance(result, Exception): self._bump_errors(result) if disconnect: await self.close() async def _send_concurrent(self, message, future, request_count): async with self._outgoing_concurrency: send_time = await self._send_message(message) try: async with timeout_after(self.sent_request_timeout): return await future finally: time_taken = max(0, time.time() - send_time) if request_count == 1: self._req_times.append(time_taken) else: self._req_times.extend([time_taken / request_count] * request_count) if len(self._req_times) >= self.recalibrate_count: self._recalc_concurrency() # External API async def connection_lost(self): self.connection.cancel_pending_requests() def default_connection(self): '''Return a default connection if the user provides none.''' return JSONRPCConnection(JSONRPCv2) def default_framer(self): '''Return a default framer.''' return NewlineFramer() async def handle_request(self, request): pass async def send_request(self, method, args=()): '''Send an RPC request over the network.''' message, future = self.connection.send_request(Request(method, args)) return await self._send_concurrent(message, future, 1) async def send_notification(self, method, args=()): '''Send an RPC notification over the network.''' message = self.connection.send_notification(Notification(method, args)) await self._send_message(message) def send_batch(self, raise_errors=False): '''Return a BatchRequest. Intended to be used like so: async with session.send_batch() as batch: batch.add_request("method1") batch.add_request("sum", (x, y)) batch.add_notification("updated") for result in batch.results: ... Note that in some circumstances exceptions can be raised; see BatchRequest doc string. 
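
        A sketch of handling a failed batch when raise_errors is True (illustrative;
        the method name is an assumption):

           try:
               async with session.send_batch(raise_errors=True) as batch:
                   batch.add_request("method1")
           except BatchError as e:
               failed = e.request   # the BatchRequest; failed.results holds the responses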
''' return BatchRequest(self, raise_errors) aiorpcX-0.24/aiorpcx/socks.py000077500000000000000000000363661474217261100162330ustar00rootroot00000000000000# Copyright (c) 2018, Neil Booth # # All rights reserved. # # The MIT License (MIT) # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the # "Software"), to deal in the Software without restriction, including # without limitation the rights to use, copy, modify, merge, publish, # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. '''SOCKS proxying.''' import asyncio import collections from ipaddress import IPv4Address, IPv6Address import secrets import socket import struct from functools import partial from aiorpcx.util import NetAddress __all__ = ('SOCKSUserAuth', 'SOCKSRandomAuth', 'SOCKS4', 'SOCKS4a', 'SOCKS5', 'SOCKSProxy', 'SOCKSError', 'SOCKSProtocolError', 'SOCKSFailure') SOCKSUserAuth = collections.namedtuple("SOCKSUserAuth", "username password") # Random authentication is useful when used with Tor for stream isolation. class SOCKSRandomAuth(SOCKSUserAuth): def __getattribute__(self, key): return secrets.token_hex(32) SOCKSRandomAuth.__new__.__defaults__ = (None, None) class SOCKSError(Exception): '''Base class for SOCKS exceptions. 
Each raised exception will be an instance of a derived class.''' class SOCKSProtocolError(SOCKSError): '''Raised when the proxy does not follow the SOCKS protocol''' class SOCKSFailure(SOCKSError): '''Raised when the proxy refuses or fails to make a connection''' class NeedData(Exception): pass class SOCKSBase: '''Stateful as written so good for a single connection only.''' @classmethod def name(cls): return cls.__name__ def __init__(self): self._buffer = bytes() self._state = self._start def _read(self, size): if len(self._buffer) < size: raise NeedData(size - len(self._buffer)) result = self._buffer[:size] self._buffer = self._buffer[size:] return result def receive_data(self, data): self._buffer += data def next_message(self): return self._state() class SOCKS4(SOCKSBase): '''SOCKS4 protocol wrapper.''' # See http://ftp.icm.edu.pl/packages/socks/socks4/SOCKS4.protocol REPLY_CODES = { 90: 'request granted', 91: 'request rejected or failed', 92: ('request rejected because SOCKS server cannot connect ' 'to identd on the client'), 93: ('request rejected because the client program and identd ' 'report different user-ids') } def __init__(self, remote_address, auth): super().__init__() self._remote_host = remote_address.host self._remote_port = remote_address.port self._auth = auth self._check_remote_host() def _check_remote_host(self): if not isinstance(self._remote_host, IPv4Address): raise SOCKSProtocolError(f'SOCKS4 requires an IPv4 address: {self._remote_host}') def _start(self): self._state = self._first_response if isinstance(self._remote_host, IPv4Address): # SOCKS4 dst_ip_packed = self._remote_host.packed host_bytes = b'' else: # SOCKS4a dst_ip_packed = b'\0\0\0\1' host_bytes = self._remote_host.encode() + b'\0' if isinstance(self._auth, SOCKSUserAuth): user_id = self._auth.username.encode() else: user_id = b'' # Send TCP/IP stream CONNECT request return b''.join([b'\4\1', struct.pack('>H', self._remote_port), dst_ip_packed, user_id, b'\0', host_bytes]) def _first_response(self): # Wait for 8-byte response data = self._read(8) if data[0] != 0: raise SOCKSProtocolError(f'invalid {self.name()} proxy ' f'response: {data}') reply_code = data[1] if reply_code != 90: msg = self.REPLY_CODES.get( reply_code, f'unknown {self.name()} reply code {reply_code}') raise SOCKSFailure(f'{self.name()} proxy request failed: {msg}') # Other fields ignored return None class SOCKS4a(SOCKS4): def _check_remote_host(self): if not isinstance(self._remote_host, (str, IPv4Address)): raise SOCKSProtocolError( f'SOCKS4a requires an IPv4 address or host name: {self._remote_host}') class SOCKS5(SOCKSBase): '''SOCKS protocol wrapper.''' # See https://tools.ietf.org/html/rfc1928 ERROR_CODES = { 1: 'general SOCKS server failure', 2: 'connection not allowed by ruleset', 3: 'network unreachable', 4: 'host unreachable', 5: 'connection refused', 6: 'TTL expired', 7: 'command not supported', 8: 'address type not supported', } def __init__(self, remote_address, auth): super().__init__() self._dst_bytes = SOCKS5._destination_bytes(remote_address.host, remote_address.port) self._auth_bytes, self._auth_methods = SOCKS5._authentication(auth) @staticmethod def _destination_bytes(host, port): if isinstance(host, IPv4Address): addr_bytes = b'\1' + host.packed elif isinstance(host, IPv6Address): addr_bytes = b'\4' + host.packed else: assert isinstance(host, str) host = host.encode() assert len(host) <= 255 addr_bytes = b'\3' + bytes([len(host)]) + host return addr_bytes + struct.pack('>H', port) @staticmethod def 
_authentication(auth): if isinstance(auth, SOCKSUserAuth): user_bytes = auth.username.encode() if not 0 < len(user_bytes) < 256: raise SOCKSProtocolError(f'username {auth.username} has ' f'invalid length {len(user_bytes)}') pwd_bytes = auth.password.encode() if not 0 < len(pwd_bytes) < 256: raise SOCKSProtocolError(f'password has invalid length ' f'{len(pwd_bytes)}') return b''.join([bytes([1, len(user_bytes)]), user_bytes, bytes([len(pwd_bytes)]), pwd_bytes]), [0, 2] return b'', [0] def _start(self): self._state = self._first_response return (b'\5' + bytes([len(self._auth_methods)]) + bytes(m for m in self._auth_methods)) def _first_response(self): # Wait for 2-byte response data = self._read(2) if data[0] != 5: raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}') if data[1] not in self._auth_methods: raise SOCKSFailure('SOCKS5 proxy rejected authentication methods') # Authenticate if user-password authentication if data[1] == 2: self._state = self._auth_response return self._auth_bytes return self._request_connection() def _auth_response(self): data = self._read(2) if data[0] != 1: raise SOCKSProtocolError(f'invalid SOCKS5 proxy auth ' f'response: {data}') if data[1] != 0: raise SOCKSFailure(f'SOCKS5 proxy auth failure code: ' f'{data[1]}') return self._request_connection() def _request_connection(self): # Send connection request self._state = self._connect_response return b'\5\1\0' + self._dst_bytes def _connect_response(self): data = self._read(5) if data[0] != 5 or data[2] != 0 or data[3] not in (1, 3, 4): raise SOCKSProtocolError(f'invalid SOCKS5 proxy response: {data}') if data[1] != 0: raise SOCKSFailure(self.ERROR_CODES.get( data[1], f'unknown SOCKS5 error code: {data[1]}')) if data[3] == 1: addr_len = 3 # IPv4 elif data[3] == 3: addr_len = data[4] # Hostname else: addr_len = 15 # IPv6 self._state = partial(self._connect_response_rest, addr_len) return self.next_message() def _connect_response_rest(self, addr_len): self._read(addr_len + 2) return None class SOCKSProxy: def __init__(self, address, protocol, auth): '''A SOCKS proxy at a NetAddress following a SOCKS protocol. auth is an authentication method to use when connecting, or None. ''' if not isinstance(address, NetAddress): address = NetAddress.from_string(address) self.address = address self.protocol = protocol self.auth = auth # Set on each successful connection via the proxy to the # result of socket.getpeername() self.peername = None def __str__(self): auth = 'username' if self.auth else 'none' return f'{self.protocol.name()} proxy at {self.address}, auth: {auth}' async def _handshake(self, client, sock, loop): while True: count = 0 try: message = client.next_message() except NeedData as e: count = e.args[0] else: if message is None: return await loop.sock_sendall(sock, message) if count: data = await loop.sock_recv(sock, count) if not data: raise SOCKSProtocolError("EOF received") client.receive_data(data) async def _connect_one(self, remote_address): '''Connect to the proxy and perform a handshake requesting a connection. Return the open socket on success, or the exception on failure. 
''' loop = asyncio.get_event_loop() for info in await loop.getaddrinfo(str(self.address.host), self.address.port, type=socket.SOCK_STREAM): # This object has state so is only good for one connection client = self.protocol(remote_address, self.auth) sock = socket.socket(family=info[0]) try: # A non-blocking socket is required by loop socket methods sock.setblocking(False) await loop.sock_connect(sock, info[4]) await self._handshake(client, sock, loop) self.peername = sock.getpeername() return sock except (OSError, SOCKSError) as e: exception = e # Don't close the socket because of an asyncio bug # see https://github.com/kyuupichan/aiorpcX/issues/8 return exception async def _connect(self, remote_addresses): '''Connect to the proxy and perform a handshake requesting a connection to each address in addresses. Return an (open_socket, remote_address) pair on success. ''' assert remote_addresses exceptions = [] for remote_address in remote_addresses: sock = await self._connect_one(remote_address) if isinstance(sock, socket.socket): return sock, remote_address exceptions.append(sock) strings = set(f'{exc!r}' for exc in exceptions) raise (exceptions[0] if len(strings) == 1 else OSError(f'multiple exceptions: {", ".join(strings)}')) async def _detect_proxy(self): '''Return True if it appears we can connect to a SOCKS proxy, otherwise False. ''' if self.protocol is SOCKS4a: remote_address = NetAddress('www.apple.com', 80) else: remote_address = NetAddress('8.8.8.8', 53) sock = await self._connect_one(remote_address) if isinstance(sock, socket.socket): sock.close() return True # SOCKSFailure indicates something failed, but that we are likely talking to a # proxy return isinstance(sock, SOCKSFailure) @classmethod async def auto_detect_at_address(cls, address, auth): '''Try to detect a SOCKS proxy at address using the authentication method (or None). SOCKS5, SOCKS4a and SOCKS are tried in order. If a SOCKS proxy is detected a SOCKSProxy object is returned. Returning a SOCKSProxy does not mean it is functioning - for example, it may have no network connectivity. If no proxy is detected return None. ''' for protocol in (SOCKS5, SOCKS4a, SOCKS4): proxy = cls(address, protocol, auth) if await proxy._detect_proxy(): return proxy return None @classmethod async def auto_detect_at_host(cls, host, ports, auth): '''Try to detect a SOCKS proxy on a host on one of the ports. Calls auto_detect_address for the ports in order. Returning a SOCKSProxy does not mean it is functioning - for example, it may have no network connectivity. If no proxy is detected return None. ''' for port in ports: proxy = await cls.auto_detect_at_address(NetAddress(host, port), auth) if proxy: return proxy return None async def create_connection(self, protocol_factory, host, port, *, resolve=False, ssl=None, family=0, proto=0, flags=0): '''Set up a connection to (host, port) through the proxy. If resolve is True then host is resolved locally with getaddrinfo using family, proto and flags, otherwise the proxy is asked to resolve host. The function signature is similar to loop.create_connection() with the same result. The attribute _address is set on the protocol to the address of the successful remote connection. Additionally raises SOCKSError if something goes wrong with the proxy handshake. 
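
        A usage sketch (the proxy address and target are assumptions, and MyProtocol
        stands for any asyncio protocol factory):

           proxy = SOCKSProxy('localhost:9050', SOCKS5, None)
           transport, protocol = await proxy.create_connection(
               MyProtocol, 'example.com', 80)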
''' loop = asyncio.get_event_loop() if resolve: remote_addresses = [NetAddress(info[4][0], info[4][1]) for info in await loop.getaddrinfo(host, port, family=family, proto=proto, type=socket.SOCK_STREAM, flags=flags)] else: remote_addresses = [NetAddress(host, port)] sock, remote_address = await self._connect(remote_addresses) def set_address(): protocol = protocol_factory() protocol._proxy = self protocol._remote_address = remote_address return protocol return await loop.create_connection(set_address, sock=sock, ssl=ssl, server_hostname=host if ssl else None) aiorpcX-0.24/aiorpcx/unixsocket.py000077500000000000000000000133531474217261100172740ustar00rootroot00000000000000# Copyright (c) 2021, Adriano Marto Reis # # All rights reserved. # # The MIT License (MIT) # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the # "Software"), to deal in the Software without restriction, including # without limitation the rights to use, copy, modify, merge, publish, # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. '''Asyncio protocol abstraction.''' __all__ = ('connect_us', 'serve_us') import asyncio from functools import partial from aiorpcx.curio import Event, timeout_after, TaskTimeout from aiorpcx.session import RPCSession, SessionBase, SessionKind class ConnectionLostError(Exception): pass class USTransport(asyncio.Protocol): def __init__(self, session_factory, framer, kind): self.session_factory = session_factory self.loop = asyncio.get_event_loop() self.session = None self.kind = kind self._asyncio_transport = None self._framer = framer # Cleared when the send socket is full self._can_send = Event() self._can_send.set() self._closed_event = Event() self._process_messages_task = None async def process_messages(self): try: await self.session.process_messages(self.receive_message) except ConnectionLostError: pass finally: self._closed_event.set() async def receive_message(self): return await self._framer.receive_message() def connection_made(self, transport): '''Called by asyncio when a connection is established.''' self._asyncio_transport = transport self.session = self.session_factory(self) self._framer = self._framer or self.session.default_framer() self._process_messages_task = self.loop.create_task(self.process_messages()) def connection_lost(self, _exeption): '''Called by asyncio when the connection closes. 
        Tear down things done in connection_made.'''
        # Release waiting tasks
        self._can_send.set()
        self._framer.fail(ConnectionLostError())

    def data_received(self, data):
        '''Called by asyncio when a message comes in.'''
        self.session.data_received(data)
        self._framer.received_bytes(data)

    def pause_writing(self):
        '''Called by asyncio when the send buffer is full.'''
        if not self.is_closing():
            self._can_send.clear()
            self._asyncio_transport.pause_reading()

    def resume_writing(self):
        '''Called by asyncio when the send buffer has room.'''
        if not self._can_send.is_set():
            self._can_send.set()
            self._asyncio_transport.resume_reading()

    # API exposed to session
    async def write(self, message):
        await self._can_send.wait()
        if not self.is_closing():
            framed_message = self._framer.frame(message)
            self._asyncio_transport.write(framed_message)

    async def close(self, force_after):
        '''Close the connection and return when closed.'''
        if self._asyncio_transport:
            self._asyncio_transport.close()
            try:
                async with timeout_after(force_after):
                    await self._closed_event.wait()
            except TaskTimeout:
                await self.abort()
                await self._closed_event.wait()

    async def abort(self):
        if self._asyncio_transport:
            self._asyncio_transport.abort()

    def is_closing(self):
        '''Return True if the connection is closing.'''
        return self._closed_event.is_set() or self._asyncio_transport.is_closing()

    def proxy(self):
        '''Not applicable to unix sockets.'''
        return None

    def remote_address(self):
        '''Not applicable to unix sockets.'''
        return None


class USClient:

    def __init__(self, path=None, *, framer=None, **kwargs):
        session_factory = kwargs.pop('session_factory', RPCSession)
        self.protocol_factory = partial(USTransport, session_factory, framer,
                                        SessionKind.CLIENT)
        self.path = path
        self.session = None
        self.loop = kwargs.get('loop', asyncio.get_event_loop())
        self.kwargs = kwargs

    async def create_connection(self):
        '''Initiate a connection.'''
        return await self.loop.create_unix_connection(
            self.protocol_factory, self.path, **self.kwargs)

    async def __aenter__(self):
        _transport, protocol = await self.create_connection()
        self.session = protocol.session
        assert isinstance(self.session, SessionBase)
        return self.session

    async def __aexit__(self, _type, _value, _traceback):
        await self.session.close()


async def serve_us(session_factory, path=None, *, framer=None, loop=None, **kwargs):
    loop = loop or asyncio.get_event_loop()
    protocol_factory = partial(USTransport, session_factory, framer, SessionKind.SERVER)
    return await loop.create_unix_server(protocol_factory, path, **kwargs)


connect_us = USClient
aiorpcX-0.24/aiorpcx/util.py000077500000000000000000000231741474217261100160570ustar00rootroot00000000000000# Copyright (c) 2018, Neil Booth
#
# All rights reserved.
#
# The MIT License (MIT)
#
# Permission is hereby granted, free of charge, to any person obtaining
# a copy of this software and associated documentation files (the
# "Software"), to deal in the Software without restriction, including
# without limitation the rights to use, copy, modify, merge, publish,
# distribute, sublicense, and/or sell copies of the Software, and to
# permit persons to whom the Software is furnished to do so, subject to
# the following conditions:
#
# The above copyright notice and this permission notice shall be
# included in all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND,
# EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF
# MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND
# NONINFRINGEMENT.
IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. __all__ = ('instantiate_coroutine', 'is_valid_hostname', 'classify_host', 'validate_port', 'validate_protocol', 'Service', 'ServicePart', 'NetAddress') import asyncio from collections import namedtuple from enum import IntEnum from functools import partial import inspect from ipaddress import ip_address, IPv4Address, IPv6Address import re # See http://stackoverflow.com/questions/2532053/validate-a-hostname-string # Note underscores are valid in domain names, but strictly invalid in host # names. We ignore that distinction. PROTOCOL_REGEX = re.compile('[A-Za-z][A-Za-z0-9+-.]+$') LABEL_REGEX = re.compile('^[a-z0-9_]([a-z0-9-_]{0,61}[a-z0-9_])?$', re.IGNORECASE) NUMERIC_REGEX = re.compile('[0-9]+$') def is_valid_hostname(hostname): '''Return True if hostname is valid, otherwise False.''' if not isinstance(hostname, str): raise TypeError('hostname must be a string') # strip exactly one dot from the right, if present if hostname and hostname[-1] == ".": hostname = hostname[:-1] if not hostname or len(hostname) > 253: return False labels = hostname.split('.') # the TLD must be not all-numeric if re.match(NUMERIC_REGEX, labels[-1]): return False return all(LABEL_REGEX.match(label) for label in labels) def classify_host(host): '''Host is an IPv4Address, IPv6Address or a string. If an IPv4Address or IPv6Address return it. Otherwise convert the string to an IPv4Address or IPv6Address object if possible and return it. Otherwise return the original string if it is a valid hostname. Raise ValueError if a string cannot be interpreted as an IP address and it is not a valid hostname. ''' if isinstance(host, (IPv4Address, IPv6Address)): return host if is_valid_hostname(host): return host return ip_address(host) def validate_port(port): '''Validate port and return it as an integer. A string, or its representation as an integer, is accepted.''' if not isinstance(port, (str, int)): raise TypeError(f'port must be an integer or string: {port}') if isinstance(port, str) and port.isdigit(): port = int(port) if isinstance(port, int) and 0 < port <= 65535: return port raise ValueError(f'invalid port: {port}') def validate_protocol(protocol): '''Validate a protocol, a string, and return it.''' if not re.match(PROTOCOL_REGEX, protocol): raise ValueError(f'invalid protocol: {protocol}') return protocol.lower() class ServicePart(IntEnum): PROTOCOL = 0 HOST = 1 PORT = 2 def _split_address(string): if string.startswith('['): end = string.find(']') if end != -1: if len(string) == end + 1: return string[1:end], '' if string[end + 1] == ':': return string[1:end], string[end+2:] colon = string.find(':') if colon == -1: return string, '' return string[:colon], string[colon + 1:] class NetAddress: def __init__(self, host, port): '''Construct a NetAddress from a host and a port. Host is classified and port is an integer.''' self._host = classify_host(host) self._port = validate_port(port) def __eq__(self, other): return self._host == other._host and self._port == other._port def __hash__(self): return hash((self._host, self._port)) @classmethod def from_string(cls, string, *, default_func=None): '''Construct a NetAddress from a string and return a (host, port) pair. 
If either (or both) is missing and default_func is provided, it is called with ServicePart.HOST or ServicePart.PORT to get a default. ''' if not isinstance(string, str): raise TypeError(f'address must be a string: {string}') host, port = _split_address(string) if default_func: host = host or default_func(ServicePart.HOST) port = port or default_func(ServicePart.PORT) if not host or not port: raise ValueError(f'invalid address string: {string}') return cls(host, port) @property def host(self): return self._host @property def port(self): return self._port def __str__(self): if isinstance(self._host, IPv6Address): return f'[{self._host}]:{self._port}' return f'{self.host}:{self.port}' def __repr__(self): return f'NetAddress({self.host!r}, {self.port})' @classmethod def default_host_and_port(cls, host, port): def func(kind): return host if kind == ServicePart.HOST else port return func @classmethod def default_host(cls, host): return cls.default_host_and_port(host, None) @classmethod def default_port(cls, port): return cls.default_host_and_port(None, port) class Service: '''A validated protocol, address pair.''' def __init__(self, protocol, address): '''Construct a service from a protocol string and a NetAddress object,''' self._protocol = validate_protocol(protocol) if not isinstance(address, NetAddress): address = NetAddress.from_string(address) self._address = address def __eq__(self, other): return self._protocol == other._protocol and self._address == other._address def __hash__(self): return hash((self._protocol, self._address)) @property def protocol(self): return self._protocol @property def address(self): return self._address @property def host(self): return self._address.host @property def port(self): return self._address.port @classmethod def from_string(cls, string, *, default_func=None): '''Construct a Service from a string. If default_func is provided and any ServicePart is missing, it is called with default_func(protocol, part) to obtain the missing part. 
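
        An illustrative call:

           Service.from_string('tcp://example.com:80')
           # -> Service('tcp', 'example.com:80')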
''' if not isinstance(string, str): raise TypeError(f'service must be a string: {string}') parts = string.split('://', 1) if len(parts) == 2: protocol, address = parts else: item, = parts protocol = None if default_func: if default_func(item, ServicePart.HOST) and default_func(item, ServicePart.PORT): protocol, address = item, '' else: protocol, address = default_func(None, ServicePart.PROTOCOL), item if not protocol: raise ValueError(f'invalid service string: {string}') if default_func: default_func = partial(default_func, protocol.lower()) address = NetAddress.from_string(address, default_func=default_func) return cls(protocol, address) def __str__(self): return f'{self._protocol}://{self._address}' def __repr__(self): return f"Service({self._protocol!r}, '{self._address}')" def instantiate_coroutine(corofunc, args): if asyncio.iscoroutine(corofunc): if args != (): raise ValueError('args cannot be passed with a coroutine') return corofunc return corofunc(*args) def is_async_call(func): '''inspect.iscoroutinefunction that looks through partials.''' while isinstance(func, partial): func = func.func return inspect.iscoroutinefunction(func) # other_params: None means cannot be called with keyword arguments only # any means any name is good SignatureInfo = namedtuple('SignatureInfo', 'min_args max_args ' 'required_names other_names') def signature_info(func): params = inspect.signature(func).parameters min_args = max_args = 0 required_names = [] other_names = [] no_names = False for p in params.values(): if p.kind == p.POSITIONAL_OR_KEYWORD: max_args += 1 if p.default is p.empty: min_args += 1 required_names.append(p.name) else: other_names.append(p.name) elif p.kind == p.KEYWORD_ONLY: other_names.append(p.name) elif p.kind == p.VAR_POSITIONAL: max_args = None elif p.kind == p.VAR_KEYWORD: other_names = any elif p.kind == p.POSITIONAL_ONLY: max_args += 1 if p.default is p.empty: min_args += 1 no_names = True if no_names: other_names = None return SignatureInfo(min_args, max_args, required_names, other_names) aiorpcX-0.24/aiorpcx/websocket.py000077500000000000000000000106071474217261100170650ustar00rootroot00000000000000# Copyright (c) 2019, Neil Booth # # All rights reserved. # # The MIT License (MIT) # # Permission is hereby granted, free of charge, to any person obtaining # a copy of this software and associated documentation files (the # "Software"), to deal in the Software without restriction, including # without limitation the rights to use, copy, modify, merge, publish, # distribute, sublicense, and/or sell copies of the Software, and to # permit persons to whom the Software is furnished to do so, subject to # the following conditions: # # The above copyright notice and this permission notice shall be # included in all copies or substantial portions of the Software. # # THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, # EXPRESS OR IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF # MERCHANTABILITY, FITNESS FOR A PARTICULAR PURPOSE AND # NONINFRINGEMENT. IN NO EVENT SHALL THE AUTHORS OR COPYRIGHT HOLDERS BE # LIABLE FOR ANY CLAIM, DAMAGES OR OTHER LIABILITY, WHETHER IN AN ACTION # OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION # WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. 
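
# A client usage sketch (illustrative; the URI and method name are assumptions):
#
#   async with connect_ws('ws://localhost:8000') as session:
#       result = await session.send_request('ping')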
from functools import partial try: from websockets import connect, serve from websockets.exceptions import ConnectionClosed except ImportError: websockets = None from aiorpcx.curio import spawn from aiorpcx.session import RPCSession, SessionKind from aiorpcx.util import NetAddress __all__ = ('serve_ws', 'connect_ws') class WSTransport: '''Implementation of a websocket transport for session.py.''' def __init__(self, websocket, session_factory, kind): self.websocket = websocket self.kind = kind self.session = session_factory(self) self.closing = False @classmethod async def ws_server(cls, session_factory, websocket): transport = cls(websocket, session_factory, SessionKind.SERVER) await transport.process_messages() @classmethod async def ws_client(cls, uri, **kwargs): session_factory = kwargs.pop('session_factory', RPCSession) websocket = await connect(uri, **kwargs) return cls(websocket, session_factory, SessionKind.CLIENT) async def recv_message(self): message = await self.websocket.recv() # It might be nice to avoid the redundant conversions if isinstance(message, str): message = message.encode() self.session.data_received(message) return message async def process_messages(self): try: await self.session.process_messages(self.recv_message) except ConnectionClosed: pass # API exposed to session async def write(self, framed_message): # Prefer to send as text try: framed_message = framed_message.decode() except UnicodeDecodeError: pass await self.websocket.send(framed_message) async def close(self, _force_after=0): '''Close the connection and return when closed.''' self.closing = True await self.websocket.close() async def abort(self): '''Abort the connection. For now this just calls close().''' self.closing = True await self.close() def is_closing(self): '''Return True if the connection is closing.''' return self.closing def proxy(self): return None def remote_address(self): result = self.websocket.remote_address if result: result = NetAddress(*result[:2]) return result class WSClient: def __init__(self, uri, **kwargs): self.uri = uri self.session_factory = kwargs.pop('session_factory', RPCSession) self.kwargs = kwargs self.transport = None self.process_messages_task = None async def __aenter__(self): self.transport = await WSTransport.ws_client(self.uri, **self.kwargs) self.process_messages_task = await spawn(self.transport.process_messages()) return self.transport.session async def __aexit__(self, exc_type, exc_value, traceback): await self.transport.close() # Disabled this as loop might not have processed the event, and don't want to sleep here # assert self.process_messages_task.done() def serve_ws(session_factory, *args, **kwargs): ws_handler = partial(WSTransport.ws_server, session_factory) return serve(ws_handler, *args, **kwargs) connect_ws = WSClient aiorpcX-0.24/docs/000077500000000000000000000000001474217261100140015ustar00rootroot00000000000000aiorpcX-0.24/docs/Makefile000077500000000000000000000011341474217261100154430ustar00rootroot00000000000000# Minimal makefile for Sphinx documentation # # You can set these variables from the command line. SPHINXOPTS = SPHINXBUILD = sphinx-build SPHINXPROJ = aiorpcX SOURCEDIR = . BUILDDIR = build # Put it first so that "make" without argument is like "make help". help: @$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O) .PHONY: help Makefile # Catch-all target: route all unknown targets to Sphinx using the new # "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS). 
%: Makefile
	@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(O)
aiorpcX-0.24/docs/changelog.rst000077500000000000000000000320401474217261100164640ustar00rootroot00000000000000ChangeLog
=========

.. note:: The aiorpcX API changes regularly and is still unstable.  I hope to finalize
          it for a 1.0 release in the coming months.

Version 0.24 (16 Jan 2025)
----------------------------

* bump websockets library to >=14.0 and Python version to >=3.9

Version 0.23.1 (17 Mar 2024)
----------------------------

* fix packaging problem with 0.23

Version 0.23 (17 Mar 2024)
----------------------------

* TaskTimeout now derives from Exception not CancelledError
* make several tests more robust
* move to flake8

Version 0.22.1 (25 May 2021)
----------------------------

* release tasks as they complete in the task group; otherwise this might appear as a
  memory leak for long-standing sessions

Version 0.22.0 (25 Apr 2021)
----------------------------

* join() waits for all cancelled tasks to finish, including daemonic ones

Version 0.21.1 (24 Apr 2021)
----------------------------

* handle peername of None in network code
* strip redundant whitespace from JSON (SomberNight)

Version 0.21.0 (11 Mar 2021)
----------------------------

* There have been significant semantic and API changes for TaskGroups.  Their behaviour
  is now consistent, reliable and they have the same semantics as curio.  As such I
  consider their API finalized and stable.  In addition to the notes below for 0.20.x:
* closed() became the attribute joined.
* cancel_remaining() does not cancel daemonic tasks.  As before it waits for the
  cancelled tasks to complete.
* On return from join() all tasks including daemonic ones have been cancelled, but
  nothing is waited for.  If leaving a TaskGroup context because of an exception,
  cancel_remaining() - which can block - is called before join().

Version 0.20.2 (10 Mar 2021)
----------------------------

* result, exception, results and exceptions are now attributes.  They raise a
  RuntimeError if accessed before a TaskGroup's join() operation has returned.

Version 0.20.1 (06 Mar 2021)
----------------------------

* this release contains some significant API changes which users will need to
  carefully check their code for.
* the report_crash argument to spawn() is removed; instead a new one is named daemon.
  A daemon task's exception (if any) is ignored by a TaskGroup.
* the join() method of TaskGroup (and so also when TaskGroup is used as a context
  manager) does not raise the exception of failed tasks.  The full semantics are
  precisely described in the TaskGroup() docstring.  Briefly: any task being cancelled
  or raising an exception causes join() to finish and all remaining tasks, including
  daemon tasks, to be cancelled.  join() does not propagate task exceptions.
* the cancel_remaining() method of TaskGroup does not propagate any task exceptions
* TaskGroup supports the additional attributes 'tasks' and 'daemons'.  Also, after
  join() has completed, result() returns the result (or raises the exception) of the
  first completed task.  exception() returns the exception (if any) of the first
  completed task.  results() returns the results of all tasks and exceptions() returns
  the exceptions raised by all tasks.  daemon tasks are ignored.
* The above changes bring the implementation in line with curio proper and the
  semantic changes it made over a year ago, and ensure that join() behaves
  consistently when called more than once.

Version 0.18.4 (20 Nov 2019)
----------------------------

* handle time.time() not making progress.
fixing `#26`_ (SomberNight) * handle SOCKSError in _connect_one (SomberNight) * add SOCKSRandomAuth: Jeremy Rand Version 0.18.3 (19 May 2019) ---------------------------- * minor bugfix release, fixing `#22`_ * make JSON IDs independent across sessions, make websockets dependency optional (SomberNight) Version 0.18.2 (19 May 2019) ---------------------------- * minor bugfix release Version 0.18.1 (09 May 2019) ---------------------------- * convert incoming websocket text frames to binary. Convert outgoing messages to text frames if possible. Version 0.18.0 (09 May 2019) ---------------------------- * Add *websocket* support as client and server by using Aymeric Augustin's excellent `websockets `_ package. Unfortunately this required changing several APIs. The code now distinguishes the previous TCP and SSL based-connections as *raw sockets* from the new websockets. The old Connector and Server classes are gone. Use `connect_rs()` and `serve_rs()` to connect a client and start a server for raw sockets; and `connect_ws()` and `serve_ws()` to do the same for websockets. SessionBase no longer inherits `asyncio.Protocol` as it is now transport-independent. Sessions no longer take a framer in their constructor: websocket messages are already framed, so instead a framer is passed to `connect_rs()` and `serve_rs()` if the default `NewlineFramer` is not wanted. A session is only instantiated when a connection handshake is completed, so `connection_made()` is no longer a method. `connection_lost()` and `abort()` are now coroutines; if overriding either be sure to call the base class implementation. `is_send_buffer_full()` was removed. * Updated and added new examples * JSON RPC message handling was made more efficient by using futures instead of events internally Version 0.17.0 (22 Apr 2019) ---------------------------- * Add some new APIs, update others * Add Service, NetAddress, ServicePart, validate_port, validate_protocol * SessionBase: new API proxy() and remote_address(). Remove peer_address() and peer_address_str() * SOCKSProxy: auto_detect_address(), auto_detect_host() renamed auto_detect_at_address() and auto_detect_at_host(). auto_detect_at_address() takes a NetAddress. Version 0.16.2 (21 Apr 2019) ---------------------------- * fix force-close bug Version 0.16.1 (20 Apr 2019) ---------------------------- * resolve socks proxy host using getaddrinfo. In particular, IPv6 is supported. * add two new APIs Version 0.16.0 (19 Apr 2019) ---------------------------- * session closing is now robust; it is safe to await session.close() from anywhere * API change: FinalRPCError removed; raise ReplyAndDisconnect instead. This responds with a normal result, or an error, and then disconnects. e.g.:: raise ReplyAndDisconnect(23) raise ReplyAndDisconnect(RPCError(1, "message")) * the session base class' private method _close() is removed. Use await close() instead. 
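  For example, the replacement call is simply::

     await session.close()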
* workaround a uvloop bug

Version 0.15.0 (16 Apr 2019)
----------------------------

* error handling improved to include costing

Version 0.14.1 (16 Apr 2019)
----------------------------

* fix a bad assertion

Version 0.14.0 (15 Apr 2019)
----------------------------

* timeout handling improvements
* RPCSession: add log_me, send_request_timeout
* Concurrency: respect semaphore queue ordering
* cleaner protocol auto-detection

Version 0.13.6 (14 Apr 2019)
----------------------------

* RPCSession: concurrency control of outgoing requests to target a given response time
* SessionBase: processing_timeout will time-out processing of incoming requests.  This
  helps prevent ever-growing request backlogs.
* SessionBase: add is_send_buffer_full()

Version 0.13.5 (13 Apr 2019)
----------------------------

* robustify concurrency handling

Version 0.13.3 (13 Apr 2019)
----------------------------

* export Concurrency class.  Tweak some default constants.

Version 0.13.2 (12 Apr 2019)
----------------------------

* wait for task to complete on close.  Concurrency improvements.

Version 0.13.0 (12 Apr 2019)
----------------------------

* fix concurrency handling; bump version as API changed

Version 0.12.1 (09 Apr 2019)
----------------------------

* improve concurrency handling; expose new API

Version 0.12.0 (09 Apr 2019)
----------------------------

* switch from bandwidth to a generic cost metric for sessions

Version 0.11.0 (06 Apr 2019)
----------------------------

* rename 'normalize_corofunc' to 'instantiate_coroutine'
* remove spawn() member of SessionBase
* add FinalRPCError (ghost43)
* more reliable cancellation on connection closing

Version 0.10.5 (16 Feb 2019)
----------------------------

* export 'normalize_corofunc'
* batches: fix handling of session loss; add test

Version 0.10.4 (07 Feb 2019)
----------------------------

* SessionBase: add closed_event, tweak closing process
* testsuite cleanup

Version 0.10.3 (07 Feb 2019)
----------------------------

* NewlineFramer: max_size of 0 does not limit buffering (SomberNight)
* trivial code / deprecation warning cleanups

Version 0.10.2 (29 Dec 2018)
----------------------------

* TaskGroup: faster cancellation (SomberNight)
* as for curio, remove wait argument to TaskGroup.join()
* setup.py: read the file to extract the version; see `#10`_

Version 0.10.1 (07 Nov 2018)
----------------------------

* bugfixes for transport closing and session task spawning

Version 0.10.0 (05 Nov 2018)
----------------------------

* add session.spawn() method
* make various member variables private

Version 0.9.1 (04 Nov 2018)
---------------------------

* abort sessions which wait too long to send a message

Version 0.9.0 (25 Oct 2018)
---------------------------

* support of binary messaging and framing
* support of plain messaging protocols.  Messages do not have an ID and do not expect
  a response; any response cannot reference the message causing it as it has no ID
  (e.g. the Bitcoin network protocol).
* removed the client / server session distinction.  As a result there is now only a
  single session class for JSONRPC-style messaging, namely RPCSession, and a single
  session class for plain messaging protocols, MessageSession.  Client connections are
  initiated by the session-independent Connector class.

Version 0.8.2 (25 Sep 2018)
---------------------------

* bw_limit defaults to 0 for ClientSession, as bandwidth limiting is mainly intended
  for servers
* don't close proxy sockets on an exception during the initial SOCKS handshake; see
  `#8`_.
This works around an asyncio bug still present in Python 3.7 * make CodeMessageError hashable. This works around a Python bug fixed somewhere between Python 3.6.4 and 3.6.6 Version 0.8.1 (12 Sep 2018) --------------------------- * remove report_crash arguments from TaskGroup methods * ignore bandwidth limits if set <= 0 Version 0.8.0 (12 Sep 2018) --------------------------- * change TaskGroup semantics: the first error of a member task is raised by the TaskGroup instead of TaskGroupError (which is now removed). Code wanting to query the status / results of member tasks should loop on group.next_done(). Version 0.7.3 (17 Aug 2018) --------------------------- * fix `#5`_; more tests added Version 0.7.2 (16 Aug 2018) --------------------------- * Restore batch functionality in Session class * Less verbose logging * Increment and test error count on protocol errors * fix `#4`_ Version 0.7.1 (09 Aug 2018) --------------------------- * TaskGroup.cancel_remaining() must wait for the tasks to complete * Fix some tests whose success / failure depended on time races * fix `#3`_ Version 0.7.0 (08 Aug 2018) --------------------------- * Fix wait=object and cancellation * Change Session and JSONRPCConnection APIs * Fix a test that would hang on some systems Version 0.6.2 (06 Aug 2018) --------------------------- * Fix a couple of issues shown up by use in ElectrumX; add testcases Version 0.6.0 (04 Aug 2018) --------------------------- * Rework the API; docs are not yet updated * New JSONRPCConnection object that manages the state of a connection, replacing the RPCProcessor class. It hides the concept of request IDs from higher layers; allowing simpler and more intuitive RPC datastructures * The API now prefers async interfaces. In particular, request handlers must be async * The API generally throws exceptions earlier for nonsense conditions * TimeOut and TaskSet classes removed; use the superior curio primitives that 0.5.7 introduced instead * SOCKS protocol implementation made i/o agnostic so the code can be used whatever your I/O framework (sync, async, threads etc). The Proxy class, like the session class, remains asyncio * Testsuite cleaned up and shrunk, now works in Python 3.7 and also tests uvloop Version 0.5.9 (29 Jul 2018) --------------------------- * Remove "async" from __aiter__ which apparently breaks Python 3.7 Version 0.5.8 (28 Jul 2018) --------------------------- * Fix __str__ in TaskGroupError Version 0.5.7 (27 Jul 2018) --------------------------- * Implement some handy abstractions from curio on top of asyncio Version 0.5.6 ------------- * Define a ConnectionError exception, and set it on uncomplete requests when a connection is lost. Previously, those requests were cancelled, which does not give an informative error message. .. _#3: https://github.com/kyuupichan/aiorpcX/issues/3 .. _#4: https://github.com/kyuupichan/aiorpcX/issues/4 .. _#5: https://github.com/kyuupichan/aiorpcX/issues/5 .. _#8: https://github.com/kyuupichan/aiorpcX/issues/8 .. _#10: https://github.com/kyuupichan/aiorpcX/issues/10 .. _#22: https://github.com/kyuupichan/aiorpcX/issues/22 .. _#26: https://github.com/kyuupichan/aiorpcX/issues/26 aiorpcX-0.24/docs/conf.py000077500000000000000000000111541474217261100153050ustar00rootroot00000000000000# -*- coding: utf-8 -*- # # Configuration file for the Sphinx documentation builder. # # This file does only contain a selection of the most common options. 
For a # full list see the documentation: # http://www.sphinx-doc.org/en/stable/config # -- Path setup -------------------------------------------------------------- # If extensions (or modules to document with autodoc) are in another directory, # add these directories to sys.path here. If the directory is relative to the # documentation root, use os.path.abspath to make it absolute, like shown here. # import os import sys sys.path.insert(0, os.path.abspath('..')) import aiorpcx # -- Project information ----------------------------------------------------- project = 'aiorpcX' copyright = '2018, Neil Booth' author = 'Neil Booth' # The short X.Y version version = aiorpcx._version_str # The full version, including alpha/beta/rc tags release = version # -- General configuration --------------------------------------------------- # If your documentation needs a minimal Sphinx version, state it here. # # needs_sphinx = '1.0' # Add any Sphinx extension module names here, as strings. They can be # extensions coming with Sphinx (named 'sphinx.ext.*') or your custom # ones. extensions = [ ] # Add any paths that contain templates here, relative to this directory. templates_path = ['_templates'] # The suffix(es) of source filenames. # You can specify multiple suffix as a list of string: # # source_suffix = ['.rst', '.md'] source_suffix = '.rst' # The master toctree document. master_doc = 'index' # The language for content autogenerated by Sphinx. Refer to documentation # for a list of supported languages. # # This is also used if you do content translation via gettext catalogs. # Usually you set "language" from the command line for these cases. language = None # List of patterns, relative to source directory, that match files and # directories to ignore when looking for source files. # This pattern also affects html_static_path and html_extra_path . exclude_patterns = [] # The name of the Pygments (syntax highlighting) style to use. pygments_style = 'sphinx' # -- Options for HTML output ------------------------------------------------- # The theme to use for HTML and HTML Help pages. See the documentation for # a list of builtin themes. # html_theme = 'alabaster' # Theme options are theme-specific and customize the look and feel of a theme # further. For a list of options available for each theme, see the # documentation. # # html_theme_options = {} # Add any paths that contain custom static files (such as style sheets) here, # relative to this directory. They are copied after the builtin static files, # so a file named "default.css" will overwrite the builtin "default.css". # html_static_path = ['_static'] # Custom sidebar templates, must be a dictionary that maps document names # to template names. # # The default sidebars (for documents that don't match any pattern) are # defined by theme itself. Builtin themes are using these templates by # default: ``['localtoc.html', 'relations.html', 'sourcelink.html', # 'searchbox.html']``. # # html_sidebars = {} # -- Options for HTMLHelp output --------------------------------------------- # Output file base name for HTML help builder. htmlhelp_basename = 'aiorpcXdoc' # -- Options for LaTeX output ------------------------------------------------ latex_elements = { # The paper size ('letterpaper' or 'a4paper'). # # 'papersize': 'letterpaper', # The font size ('10pt', '11pt' or '12pt'). # # 'pointsize': '10pt', # Additional stuff for the LaTeX preamble. 
# # 'preamble': '', # Latex figure (float) alignment # # 'figure_align': 'htbp', } # Grouping the document tree into LaTeX files. List of tuples # (source start file, target name, title, # author, documentclass [howto, manual, or own class]). latex_documents = [ (master_doc, 'aiorpcX.tex', 'aiorpcX Documentation', 'Neil Booth', 'manual'), ] # -- Options for manual page output ------------------------------------------ # One entry per manual page. List of tuples # (source start file, name, description, authors, manual section). man_pages = [ (master_doc, 'aiorpcx', 'aiorpcX Documentation', [author], 1) ] # -- Options for Texinfo output ---------------------------------------------- # Grouping the document tree into Texinfo files. List of tuples # (source start file, target name, title, author, # dir menu entry, description, category) texinfo_documents = [ (master_doc, 'aiorpcX', 'aiorpcX Documentation', author, 'aiorpcX', 'One line description of project.', 'Miscellaneous'), ] aiorpcX-0.24/docs/framing.rst000077500000000000000000000027151474217261100161660ustar00rootroot00000000000000.. currentmodule:: aiorpcx Framing ======= Message :dfn:`framing` is the method by which RPC messages are wrapped in a byte stream so that message boundaries can be determined. :mod:`aiorpcx` provides an abstract base class for framers, and a single implementation: :class:`NewlineFramer`. A framer must know how to take outgoing messages and frame them, and also how to break an incoming byte stream into message frames in order to extract the RPC messages from it. .. class:: FramerBase Derive from this class to implement your own message framing methodology. .. method:: frame(messages) Frame each message and return the concatenated result. :param messages: an iterable; each message should be of type :class:`bytes` or :class:`bytearray` :return: the concatenated bytestream :rtype: bytes .. method:: messages(data) A generator that breaks the accumulated byte stream into messages, yielding each complete message in turn. :param data: incoming data of type :class:`bytes` or :class:`bytearray` :raises MemoryError: if the internal data buffer overflows .. note:: since this may raise an exception, the caller should process messages as they are yielded. Converting the messages to a list will lose earlier ones if an exception is raised later. .. class:: NewlineFramer(max_size=1000000) A framer where messages are delimited by an ASCII newline character in a text stream. The internal buffer for partial messages will hold up to *max_size* bytes. aiorpcX-0.24/docs/index.rst000077500000000000000000000032721474217261100156510ustar00rootroot00000000000000======= aiorpcX ======= .. image:: https://badge.fury.io/py/aiorpcX.svg :target: http://badge.fury.io/py/aiorpcX .. image:: https://travis-ci.org/kyuupichan/aiorpcX.svg?branch=master :target: https://travis-ci.org/kyuupichan/aiorpcX .. image:: https://coveralls.io/repos/github/kyuupichan/aiorpcX/badge.svg :target: https://coveralls.io/github/kyuupichan/aiorpcX A generic asyncio library implementation of RPC suitable for an application that is a client, server or both. The package includes a module with full coverage of `JSON RPC `_ versions 1.0 and 2.0, JSON RPC protocol auto-detection, and arbitrary message framing. It also comes with a SOCKS proxy client. The current version is |release|. The library API is not stable and may change radically. These docs are out of date and will be updated when the API settles. Source Code =========== The project is hosted on `GitHub `_ and uses `Travis `_ for Continuous Integration. Python version at least 3.9 is required.
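As a quick taste, a minimal client modelled on the examples shipped in the source tree looks like this (a sketch only, given the API is still settling; it assumes one of the example servers is listening on localhost port 8888)::

    import asyncio

    import aiorpcx

    async def main():
        # Open a JSON RPC session over a raw TCP socket
        async with aiorpcx.connect_rs('localhost', 8888) as session:
            # Send a request and wait for its response
            result = await session.send_request('echo', ["hello"])
            print(result)

    asyncio.get_event_loop().run_until_complete(main())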
Please submit an issue on the `bug tracker `_ if you have found a bug or have a suggestion to improve the library. Authors and License =================== Neil Booth wrote the code, which is derived from the original JSON RPC code of `ElectrumX `_. The code is released under the `MIT Licence `_. Documentation ============= .. toctree:: changelog framing json-rpc rpc session socks Indices and tables ================== * :ref:`genindex` * :ref:`search` aiorpcX-0.24/docs/json-rpc.rst000077500000000000000000000063401474217261100162740ustar00rootroot00000000000000.. currentmodule:: aiorpcx JSON RPC ======== The :mod:`aiorpcx` module provides classes to interpret and construct JSON RPC protocol messages. Class instances are not used; all methods are class methods. Just call methods on the classes directly. .. class:: JSONRPC An abstract base class for concrete protocol classes. :class:`JSONRPCv1` and :class:`JSONRPCv2` are derived protocol classes implementing JSON RPC versions 1.0 and 2.0 in a strict way. .. class:: JSONRPCv1 A derived class of :class:`JSONRPC` implementing version 1.0 of the specification. .. class:: JSONRPCv2 A derived class of :class:`JSONRPC` implementing version 2.0 of the specification. .. class:: JSONRPCLoose A derived class of :class:`JSONRPC`. It accepts messages that conform to either version 1.0 or version 2.0. As it is loose, it will also accept messages that conform strictly to neither version. Unfortunately it is not possible to send messages that are acceptable to strict implementations of both versions 1.0 and 2.0, so it sends version 2.0 messages. .. class:: JSONRPCAutoDetect Auto-detects the JSON RPC protocol version spoken by the remote side based on the first incoming message, from :class:`JSONRPCv1`, :class:`JSONRPCv2` and :class:`JSONRPCLoose`. The RPC processor will then switch to that protocol version. Message interpretation ---------------------- .. classmethod:: JSONRPC.message_to_item(message) Convert a binary message into an RPC object describing the message and return it. :param bytes message: the message to interpret :return: the RPC object :rtype: :class:`RPCRequest`, :class:`RPCResponse` or :class:`RPCBatch`. If the message is ill-formed, return an :class:`RPCRequest` object with its :attr:`method` set to an :class:`RPCError` instance describing the error. Message construction -------------------- These functions convert an RPC item into a binary message that can be passed over the network after framing. .. classmethod:: JSONRPC.request_message(item) Convert a request item to a message. :param item: an :class:`RPCRequest` item :return: the message :rtype: bytes .. classmethod:: JSONRPC.response_message(item) Convert a response item to a message. :param item: an :class:`RPCResponse` item :return: the message :rtype: bytes .. classmethod:: JSONRPC.error_message(item) Convert an error item to a message. :param item: an :class:`RPCError` item :return: the message :rtype: bytes .. classmethod:: JSONRPC.batch_message(item) Convert a batch item to a message. :param item: an :class:`RPCBatch` item :return: the message :rtype: bytes .. classmethod:: JSONRPC.encode_payload(payload) Encode a Python object as a JSON string and convert it to bytes. If the object cannot be encoded as JSON, a JSON "internal error" error message is returned instead, with ID equal to the "id" member of `payload` if that is a dictionary, otherwise :const:`None`. :param payload: a Python object that can be represented as JSON. 
Numbers, strings, lists, dictionaries, :const:`True`, :const:`False` and :const:`None` are all valid. :return: a JSON message :rtype: bytes aiorpcX-0.24/docs/make.bat000077500000000000000000000014571474217261100154140ustar00rootroot00000000000000@ECHO OFF pushd %~dp0 REM Command file for Sphinx documentation if "%SPHINXBUILD%" == "" ( set SPHINXBUILD=sphinx-build ) set SOURCEDIR=source set BUILDDIR=build set SPHINXPROJ=aiorpcX if "%1" == "" goto help %SPHINXBUILD% >NUL 2>NUL if errorlevel 9009 ( echo. echo.The 'sphinx-build' command was not found. Make sure you have Sphinx echo.installed, then set the SPHINXBUILD environment variable to point echo.to the full path of the 'sphinx-build' executable. Alternatively you echo.may add the Sphinx directory to PATH. echo. echo.If you don't have Sphinx installed, grab it from echo.http://sphinx-doc.org/ exit /b 1 ) %SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% goto end :help %SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% :end popd aiorpcX-0.24/docs/rpc.rst000077500000000000000000000160641474217261100153270ustar00rootroot00000000000000.. currentmodule:: aiorpcx RPC items ========= The :mod:`aiorpcx` module defines some classes, instances of which will be returned by some of its APIs. You should not need to instantiate these objects directly. An instance of one of these classes is called an :dfn:`item`. .. class:: RPCRequest An RPC request or notification that has been received, or an outgoing notification. Outgoing requests are represented by :class:`RPCRequestOut` objects. .. attribute:: method The RPC method being invoked, a string. If an incoming request is ill-formed, so that, e.g., its method could not be determined, then this will be an :class:`RPCError` instance that describes the error. .. attribute:: args The arguments passed to the RPC method. This is a list or a dictionary; a dictionary if the arguments were passed by parameter name. .. attribute:: request_id The ID given to the request so that responses can be associated with requests. Normally an integer, or :const:`None` if the request is a :dfn:`notification`. Rarely it might be a floating point number or string. .. method:: is_notification() Returns :const:`True` if the request is a notification (its :attr:`request_id` is :const:`None`), otherwise :const:`False`. .. class:: RPCRequestOut An outgoing RPC request that is not a notification. A subclass of :class:`RPCRequest` and :class:`asyncio.Future `. When an outgoing request is created, typically via the :meth:`send_request` method of a client or server session, you can specify a callback to be called when the request is done. The callback is passed the request object, and the result can be obtained via its :meth:`result` method. A request can also be await-ed. Currently the result of await-ing is the same as calling :meth:`result` on the request but this may change in future. .. class:: RPCResponse An incoming or outgoing response. Outgoing response objects are automatically created by the framework when a request handler returns its result. .. attribute:: result The response result, a Python object. If an error occurred this will be an :class:`RPCError` object describing the error. .. attribute:: request_id The ID of the request this is a response to. Notifications do not get responses so this will never be :const:`None`. If :attr:`result` is an :class:`RPCError`, their :attr:`request_id` attributes will match. ..
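note:: Since :attr:`result` holds either the call's result or an :class:`RPCError`, receiving code should test it, for example with ``isinstance(response.result, RPCError)``, before treating it as a successful result.

..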
class:: RPCError Represents an error, either in an :class:`RPCResponse` object if an error occurred processing a request, or in a :class:`RPCRequest` if an incoming request was ill-formed. .. attribute:: message The error message as a string. .. attribute:: code The error code, an integer. .. attribute:: request_id The ID of the request that gave an error if it could be determined, otherwise :const:`None`. .. class:: RPCBatch Represents an incoming or outgoing RPC response batch, or an incoming RPC request batch. .. attribute:: items A list of the items in the batch. The list cannot be empty, and each item will be an :class:`RPCResponse` object for a response batch, and an :class:`RPCRequest` object for a request batch. Notifications and requests can be mixed together. Batches are iterable through their items, and taking their length returns the length of the items list. .. method:: requests A generator that yields non-notification items of a request batch, or each item for a response batch. .. method:: request_ids A *frozenset* of all request IDs in the batch, ignoring notifications. .. method:: is_request_batch Return :const:`True` if the batch is a request batch. .. class:: RPCBatchOut An outgoing RPC batch. A subclass of :class:`RPCBatch` and :class:`asyncio.Future `. When an outgoing request batch is created, typically via the :meth:`new_batch` method of a client or server session, you can specify a callback to be called when the batch is done. The callback is passed the batch object. Each non-notification item in an :class:`RPCBatchOut` object is itself an :class:`RPCRequestOut` object that can be independently waited on or cancelled. Notification items are :class:`RPCRequest` objects. Since batches are responded to as a whole, all member requests will be completed simultaneously. The order of callbacks of member requests, and of the batch itself, is unspecified. Cancelling a batch, or calling its :meth:`set_result` or :meth:`set_exception` methods cancels all its requests. .. method:: add_request(method, args=None, on_done=None) Add a request to the batch. A callback can be specified that will be called when the request completes. Returns the :class:`RPCRequestOut` request that was added to the batch. .. method:: add_notification(method, args=None) Add a notification to the batch. RPC Protocol Classes -------------------- RPC protocol classes should inherit from :class:`RPCProtocolBase`. The base class provides a few utility functions returning :class:`RPCError` objects. The derived class should redefine some constant class attributes. .. class:: RPCProtocolBase .. attribute:: INTERNAL_ERROR The integer error code to use for an internal error. .. attribute:: INVALID_ARGS The integer error code to use when an RPC request passes invalid arguments. .. attribute:: INVALID_REQUEST The integer error code to use when an RPC request is invalid. .. attribute:: METHOD_NOT_FOUND The integer error code to use when an RPC request is for a non-existent method. .. classmethod:: JSONRPC.internal_error(request_id) Return an :class:`RPCError` object with error code :attr:`INTERNAL_ERROR` for the given request ID. The error message will be ``"internal error processing request"``. :param request_id: the request ID, normally an integer or string :return: the error object :rtype: :class:`RPCError` .. classmethod:: JSONRPC.args_error(message) Return an :class:`RPCError` object with error code :attr:`INVALID_ARGS` with the given error message and a request ID of :const:`None`. 
:param str message: the error message :return: the error object :rtype: :class:`RPCError` .. classmethod:: JSONRPC.invalid_request(message, request_id=None) Return an :class:`RPCError` object with error code :attr:`INVALID_REQUEST` with the given error message and request ID. :param str message: the error message :param request_id: the request ID, normally an integer or string :return: the error object :rtype: :class:`RPCError` .. classmethod:: JSONRPC.method_not_found(message) Return an :class:`RPCError` object with error code :attr:`METHOD_NOT_FOUND` with the given error message and a request ID of :const:`None`. :param str message: the error message :return: the error object :rtype: :class:`RPCError` aiorpcX-0.24/docs/session.rst000077500000000000000000000100071474217261100162220ustar00rootroot00000000000000.. currentmodule:: aiorpcx Exceptions ---------- .. exception:: ConnectionError When a connection is lost that has pending requests, this exception is set on those requests. Server ====== A simple wrapper around an :class:`asyncio.Server` object (see `asyncio.Server `_). .. class:: Server(protocol_factory, host=None, port=None, *, loop=None, \ **kwargs) Creates a server that listens for connections on *host* and *port*. The server does not actually start listening until :meth:`listen` is await-ed. *protocol_factory* is any callable returning an :class:`asyncio.Protocol` instance. You might find returning an instance of :class:`ServerSession`, or a class derived from it, more useful. *loop* is the event loop to use, or :func:`asyncio.get_event_loop()` if :const:`None`. *kwargs* are passed through to `loop.create_server()`_. A server instance has the following attributes: .. attribute:: loop The event loop being used. .. attribute:: host The host passed to the constructor. .. attribute:: port The port passed to the constructor. .. attribute:: server The underlying :class:`asyncio.Server` object when the server is listening, otherwise :const:`None`. .. method:: listen() Start listening for incoming connections. Return an :class:`asyncio.Server` instance, which can also be accessed via :attr:`server`. This method is a `coroutine`_. .. method:: close() Close the listening socket if the server is listening, and wait for it to close. Return immediately if the server is not listening. This does nothing to protocols and transports handling existing connections. On return :attr:`server` is :const:`None`. .. method:: wait_closed() Returns when the server has closed. This method is a `coroutine`_. Sessions ======== Convenience classes are provided for client and server sessions. .. class:: ClientSession(host, port, *, rpc_protocol=None, framer=None, \ scheduler=None, loop=None, proxy=None, **kwargs) An instance of an :class:`asyncio.Protocol` class that represents an RPC session with a remote server at *host* and *port*, as documented in `loop.create_connection()`_. If *proxy* is not given, :meth:`create_connection` uses :meth:`loop.create_connection` to attempt a connection, otherwise :meth:`SOCKSProxy.create_connection`. You can pass additional arguments to those functions with *kwargs* (*host*, *port* and *loop* are used as given). *rpc_protocol* specifies the RPC protocol the server speaks. If :const:`None` the protocol returned by :meth:`default_rpc_protocol` is used. *framer* handles RPC message framing, and if :const:`None` then the framer returned by :meth:`default_framer` is used. *scheduler* should be left as :const:`None`.
Logging is sent to *logger*; if :const:`None`, a logger specific to the :class:`ClientSession` object's class is used. .. method:: create_connection() Make a connection attempt to the remote server. If successful this returns a ``(transport, protocol)`` pair. This method is a `coroutine`_. .. method:: default_rpc_protocol() You can override this method to provide a default RPC protocol. :class:`JSONRPCv2` is returned by the default implementation. .. method:: default_framer() You can override this method to provide a default message framer. A new :class:`NewlineFramer` instance is returned by the default implementation. The :class:`ClientSession` and :class:`ServerSession` classes share a base class that has the following attributes and methods: .. _loop.create_connection(): https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.AbstractEventLoop.create_connection .. _loop.create_server(): https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.AbstractEventLoop.create_server .. _coroutine: https://docs.python.org/3/library/asyncio-task.html#coroutine aiorpcX-0.24/docs/socks.rst000077500000000000000000000124651474217261100156660ustar00rootroot00000000000000.. currentmodule:: aiorpcx SOCKS Proxy =========== The :mod:`aiorpcx` package includes a `SOCKS `_ proxy client. It understands the ``SOCKS4``, ``SOCKS4a`` and ``SOCKS5`` protocols. Exceptions ---------- .. exception:: SOCKSError The base class of SOCKS exceptions. Each raised exception will be an instance of a derived class. .. exception:: SOCKSProtocolError A subclass of :class:`SOCKSError`. Raised when the proxy does not follow the ``SOCKS`` protocol. .. exception:: SOCKSFailure A subclass of :class:`SOCKSError`. Raised when the proxy refuses or fails to make a connection. Authentication -------------- Currently the only supported authentication method is with a username and password. Usernames can be used by all SOCKS protocols, but only ``SOCKS5`` uses the password. .. class:: SOCKSUserAuth A :class:`namedtuple` for authentication with a SOCKS server. It has two members: .. attribute:: username A string. .. attribute:: password A string. Ignored by the :class:`SOCKS4` and :class:`SOCKS4a` protocols. Protocols --------- When creating a :class:`SOCKSProxy` object, a protocol must be specified and be one of the following. .. class:: SOCKS4 An abstract class representing the ``SOCKS4`` protocol. .. class:: SOCKS4a An abstract class representing the ``SOCKS4a`` protocol. .. class:: SOCKS5 An abstract class representing the ``SOCKS5`` protocol. Proxy ----- You can create a :class:`SOCKSProxy` object directly, but using one of its auto-detection class methods is likely more useful. .. class:: SOCKSProxy(address, protocol, auth) An object representing a SOCKS proxy. The address is a Python socket `address `_, typically a ``(host, port)`` pair for IPv4, and a ``(host, port, flowinfo, scopeid)`` tuple for IPv6. The *protocol* is one of :class:`SOCKS4`, :class:`SOCKS4a` and :class:`SOCKS5`. *auth* is a :class:`SOCKSUserAuth` object or :const:`None`. After construction, :attr:`host`, :attr:`port` and :attr:`peername` are set to :const:`None`. .. classmethod:: auto_detect_address(address, auth, \*, \ loop=None, timeout=5.0) Try to detect a SOCKS proxy at *address*. Protocols :class:`SOCKS5`, :class:`SOCKS4a` and :class:`SOCKS4` are tried in order. If a SOCKS proxy is detected, return a :class:`SOCKSProxy` object, otherwise :const:`None`.
Returning a proxy object only means one was detected, not that it is functioning - for example, it may not have full network connectivity. *auth* is a :class:`SOCKSUserAuth` object or :const:`None`. If testing any protocol takes more than *timeout* seconds, it is timed out and taken as not detected. This class method is a `coroutine`_. .. classmethod:: auto_detect_host(host, ports, auth, \*, \ loop=None, timeout=5.0) Try to detect a SOCKS proxy on *host* on one of the *ports*. Call :meth:`auto_detect_address` for each ``(host, port)`` pair until a proxy is detected, and return it, otherwise :const:`None`. *auth* is a :class:`SOCKSUserAuth` object or :const:`None`. If testing any protocol on any port takes more than *timeout* seconds, it is timed out and taken as not detected. This class method is a `coroutine`_. .. method:: create_connection(protocol_factory, host, port, \*, \ resolve=False, loop=None, ssl=None, family=0, proto=0, \ flags=0, timeout=30.0) Connect to (host, port) through the proxy in the background. When successful, the coroutine returns a ``(transport, protocol, address)`` triple, and sets the proxy attribute :attr:`peername`. * If *resolve* is :const:`True`, *host* is resolved locally rather than by the proxy. *family*, *proto*, *flags* are the optional address family, protocol and flags passed to `loop.getaddrinfo()`_ to get a list of remote addresses. If given, these should all be integers from the corresponding :mod:`socket` module constants. * *ssl* is as documented for `loop.create_connection()`_. If successfully connected, the :attr:`_address` member of the protocol is set. If *resolve* is :const:`True` it is set to the successful address, otherwise ``(host, port)``. If connecting takes more than *timeout* seconds an :exc:`asyncio.TimeoutError` exception is raised. This method is a `coroutine`_. .. attribute:: host Set on a successful :meth:`create_connection` to the host passed to the proxy server. This will be the resolved address if its *resolve* argument was :const:`True`. .. attribute:: port Set on a successful :meth:`create_connection` to the port passed to the proxy server. .. attribute:: peername Set on a successful :meth:`create_connection` to the result of :meth:`socket.getpeername` on the socket connected to the proxy. .. _loop.create_connection(): https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.AbstractEventLoop.create_connection .. _loop.getaddrinfo(): https://docs.python.org/3/library/asyncio-eventloop.html#asyncio.AbstractEventLoop.getaddrinfo .. _coroutine: https://docs.python.org/3/library/asyncio-task.html#coroutine aiorpcX-0.24/examples/000077500000000000000000000000001474217261100146675ustar00rootroot00000000000000aiorpcX-0.24/examples/client_rs.py000077500000000000000000000027671474217261100172400ustar00rootroot00000000000000import asyncio import aiorpcx async def main(host, port): async with aiorpcx.connect_rs(host, port) as session: # A good request with standard argument passing result = await session.send_request('echo', ["Howdy"]) print(result) # A good request with named argument passing result = await session.send_request('echo', {'message': "Hello with a named argument"}) print(result) # aiorpcX transparently handles erroneous calls server-side, returning appropriate # errors. This in turn causes an exception to be raised in the client. for bad_args in ( ['echo'], ['echo', {}], ['foo'], # This causes an error running the server's buggy request handler.
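# (the server's handle_sum tries to add the string "b" to an integer, raising TypeError)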
# aiorpcX catches the problem, returning an 'internal server error' to the # client, and continues serving ['sum', [2, 4, "b"]] ): try: await session.send_request(*bad_args) except Exception as exc: print(repr(exc)) # Batch requests async with session.send_batch() as batch: batch.add_request('echo', ["Me again"]) batch.add_notification('ping') batch.add_request('sum', list(range(50))) for n, result in enumerate(batch.results, start=1): print(f'batch result #{n}: {result}') asyncio.get_event_loop().run_until_complete(main('localhost', 8888)) aiorpcX-0.24/examples/client_us.py000077500000000000000000000027521474217261100172370ustar00rootroot00000000000000import asyncio import aiorpcx async def main(path): async with aiorpcx.connect_us(path) as session: # A good request with standard argument passing result = await session.send_request('echo', ["Howdy"]) print(result) # A good request with named argument passing result = await session.send_request('echo', {'message': "Hello with a named argument"}) print(result) # aiorpcX transparently handles erroneous calls server-side, returning appropriate # errors. This in turn causes an exception to be raised in the client. for bad_args in ( ['echo'], ['echo', {}], ['foo'], # This causes an error running the server's buggy request handler. # aiorpcX catches the problem, returning an 'internal server error' to the # client, and continues serving ['sum', [2, 4, "b"]] ): try: await session.send_request(*bad_args) except Exception as exc: print(repr(exc)) # Batch requests async with session.send_batch() as batch: batch.add_request('echo', ["Me again"]) batch.add_notification('ping') batch.add_request('sum', list(range(50))) for n, result in enumerate(batch.results, start=1): print(f'batch result #{n}: {result}') asyncio.get_event_loop().run_until_complete(main('/tmp/test.sock')) aiorpcX-0.24/examples/client_ws.py000077500000000000000000000030151474217261100172320ustar00rootroot00000000000000import asyncio import aiorpcx async def connect(uri): async with aiorpcx.connect_ws(uri) as session: # A good request with standard argument passing result = await session.send_request('echo', ["Howdy"]) print(result) # A good request with named argument passing result = await session.send_request('echo', {'message': "Hello with a named argument"}) print(result) # aiorpcX transparently handles erroneous calls server-side, returning appropriate # errors. This in turn causes an exception to be raised in the client. for bad_args in ( ['echo'], ['echo', {}], ['foo'], # This causes an error running the server's buggy request handler. # aiorpcX catches the problem, returning an 'internal server error' to the # client, and continues serving ['sum', [2, 4, "b"]] ): try: await session.send_request(*bad_args) except Exception as exc: print('Send request exception:', repr(exc)) # Batch requests async with session.send_batch() as batch: batch.add_request('echo', ["Me again"]) batch.add_notification('ping') batch.add_request('sum', list(range(50))) for n, result in enumerate(batch.results, start=1): print(f'batch result #{n}: {result}') asyncio.get_event_loop().run_until_complete(connect('ws://localhost:8889')) aiorpcX-0.24/examples/server_rs.py000077500000000000000000000017431474217261100172630ustar00rootroot00000000000000import asyncio import aiorpcx # Handlers are declared as normal python functions.
aiorpcx automatically checks RPC # arguments, including named arguments, and returns errors as appropriate async def handle_echo(message): return message async def handle_sum(*values): return sum(values, 0) handlers = { 'echo': handle_echo, 'sum': handle_sum, } class ServerSession(aiorpcx.RPCSession): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) print(f'connection from {self.remote_address()}') async def connection_lost(self): await super().connection_lost() print(f'{self.remote_address()} disconnected') async def handle_request(self, request): handler = handlers.get(request.method) coro = aiorpcx.handler_invocation(handler, request)() return await coro loop = asyncio.get_event_loop() loop.run_until_complete(aiorpcx.serve_rs(ServerSession, 'localhost', 8888)) loop.run_forever() aiorpcX-0.24/examples/server_us.py000077500000000000000000000016521474217261100172650ustar00rootroot00000000000000import asyncio import aiorpcx # Handlers are declared as normal python functions. aiorpcx automatically checks RPC # arguments, including named arguments, and returns errors as appropriate async def handle_echo(message): return message async def handle_sum(*values): return sum(values, 0) handlers = { 'echo': handle_echo, 'sum': handle_sum, } class ServerSession(aiorpcx.RPCSession): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) print('connected') async def connection_lost(self): await super().connection_lost() print('disconnected') async def handle_request(self, request): handler = handlers.get(request.method) coro = aiorpcx.handler_invocation(handler, request)() return await coro loop = asyncio.get_event_loop() loop.run_until_complete(aiorpcx.serve_us(ServerSession, '/tmp/test.sock')) loop.run_forever() aiorpcX-0.24/examples/server_ws.py000077500000000000000000000017431474217261100172700ustar00rootroot00000000000000import asyncio import aiorpcx # Handlers are declared as normal python functions. 
aiorpcx automatically checks RPC # arguments, including named arguments, and returns errors as appropriate async def handle_echo(message): return message async def handle_sum(*values): return sum(values, 0) handlers = { 'echo': handle_echo, 'sum': handle_sum, } class ServerSession(aiorpcx.RPCSession): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) print(f'connection from {self.remote_address()}') async def connection_lost(self): await super().connection_lost() print(f'{self.remote_address()} disconnected') async def handle_request(self, request): handler = handlers.get(request.method) coro = aiorpcx.handler_invocation(handler, request)() return await coro loop = asyncio.get_event_loop() loop.run_until_complete(aiorpcx.serve_ws(ServerSession, 'localhost', 8889)) loop.run_forever() aiorpcX-0.24/pyproject.toml000066400000000000000000000014441474217261100157700ustar00rootroot00000000000000[build-system] requires = ["setuptools"] build-backend = "setuptools.build_meta" [project] name = 'aiorpcX' dynamic = ['version', 'readme'] requires-python = '>=3.9' dependencies = [] description = 'Generic async RPC implementation, including JSON-RPC' authors = [{name = 'Neil Booth', email = 'kyuupichan@pm.me'}] license = {file = "LICENCE"} urls = {'Project-URL' = 'https://github.com/kyuupichan/aiorpcX'} classifiers = [ 'Development Status :: 4 - Beta', 'Framework :: AsyncIO', 'Intended Audience :: Developers', 'License :: OSI Approved :: MIT License', 'Operating System :: OS Independent', "Programming Language :: Python :: 3.9", 'Topic :: Internet', "Topic :: Software Development :: Libraries :: Python Modules", ] [project.optional-dependencies] 'ws' = ['websockets>=14.0'] aiorpcX-0.24/pytest.ini000066400000000000000000000000351474217261100151000ustar00rootroot00000000000000[pytest] asyncio_mode = auto aiorpcX-0.24/setup.py000077500000000000000000000015311474217261100145660ustar00rootroot00000000000000import os.path import re import setuptools def find_version(filename): with open(filename) as f: text = f.read() match = re.search(r"^_version_str = '(.*)'$", text, re.MULTILINE) if not match: raise RuntimeError('cannot find version') return match.group(1) tld = os.path.abspath(os.path.dirname(__file__)) version = find_version(os.path.join(tld, 'aiorpcx', '__init__.py')) setuptools.setup( version=version, python_requires='>=3.9', packages=['aiorpcx'], # Tell setuptools to include data files specified by MANIFEST.in. include_package_data=True, download_url=('https://github.com/kyuupichan/aiorpcX/archive/' f'{version}.tar.gz'), long_description=( 'Transport, protocol and framing-independent async RPC ' 'client and server implementation. 
' ), ) aiorpcX-0.24/tests/000077500000000000000000000000001474217261100142135ustar00rootroot00000000000000aiorpcX-0.24/tests/conftest.py000066400000000000000000000006741474217261100164210ustar00rootroot00000000000000# Pytest looks here for fixtures import asyncio import pytest try: import uvloop loop_params = (False, True) except ImportError: loop_params = (False, ) # This runs all the tests once with plain asyncio, then again with uvloop @pytest.fixture(scope="session", autouse=True, params=loop_params) def use_uvloop(request): if request.param: import uvloop asyncio.set_event_loop_policy(uvloop.EventLoopPolicy()) aiorpcX-0.24/tests/test_curio.py000077500000000000000000001233711474217261100167570ustar00rootroot00000000000000import asyncio from asyncio import get_event_loop import pytest from aiorpcx.curio import ( sleep, TaskGroup, spawn, CancelledError, Event, TaskTimeout, timeout_after, ignore_at, ignore_after, TimeoutCancellationError, UncaughtTimeoutError, run_in_thread, timeout_at, Lock, Queue, Semaphore, ) def sum_all(*values): return sum(values) async def my_raises(exc): raise exc async def return_value(x, secs=0): if secs: await sleep(secs) return x # Test exports sleep CancelledError Event Lock Queue Semaphore @pytest.mark.asyncio async def test_run_in_thread(): assert await run_in_thread(sum_all) == 0 assert await run_in_thread(sum_all, 1) == 1 assert await run_in_thread(sum_all, 1, 2, 3) == 6 @pytest.mark.asyncio async def test_next_done_1(): t = TaskGroup() assert t.completed is None assert await t.next_done() is None assert await t.next_done() is None @pytest.mark.asyncio async def test_next_done_2(): tasks = () t = TaskGroup(tasks) assert t.completed is None assert await t.next_done() is None await t.join() assert t.completed is None @pytest.mark.asyncio async def test_next_done_3(): tasks = (await spawn(sleep, 0.01), await spawn(sleep, 0.02)) t = TaskGroup(tasks) assert (await t.next_done(), await t.next_done()) == tasks assert await t.next_done() is None assert t.completed is None await t.join() assert t.completed is None assert await t.next_done() is None @pytest.mark.asyncio async def test_next_done_4(): tasks = (await spawn(sleep, 0), await spawn(sleep, 0.01)) tasks[0].cancel() await sleep(0) t = TaskGroup(tasks) assert (await t.next_done(), await t.next_done()) == tasks assert await t.next_done() is None @pytest.mark.asyncio async def test_next_done_5(): tasks = (await spawn(sleep(0.02)), await spawn(sleep, 0.01), await spawn(sleep, 0.03)) t = TaskGroup(tasks) assert await t.next_done() == tasks[1] assert await t.next_done() == tasks[0] await t.join() assert t.completed is tasks[2] @pytest.mark.asyncio async def test_next_done_6(): tasks = (await spawn(sleep, 0.02), await spawn(sleep, 0.01)) for task in tasks: task.cancel() t = TaskGroup(tasks) assert await t.next_done() == tasks[0] assert await t.next_done() == tasks[1] assert await t.next_done() is None @pytest.mark.asyncio async def test_next_daemons(): tasks = (await spawn(sleep, 0.1, daemon=True), await spawn(sleep, 0.001)) t = TaskGroup(tasks) assert await t.next_done() == tasks[1] assert not tasks[0].done() assert await t.next_done() is None assert not tasks[0].done() await tasks[0] @pytest.mark.asyncio async def test_next_result(): t = TaskGroup() with pytest.raises(RuntimeError): await t.next_result() tasks = () t = TaskGroup(tasks) with pytest.raises(RuntimeError): await t.next_result() tasks = (await spawn(return_value(1)), await spawn(return_value(2))) t = TaskGroup(tasks) assert (await t.next_result(),
await t.next_result()) == (1, 2) with pytest.raises(RuntimeError): await t.next_result() @pytest.mark.asyncio async def test_tg_results_exceptions_good(): tasks = [ await spawn(return_value(1, 0.003)), await spawn(return_value(2, 0.002)), await spawn(return_value(3, 0.001)), ] async with TaskGroup(tasks, retain=True) as t: pass assert set(t.results) == {1, 2, 3} assert t.exceptions == [None] * 3 @pytest.mark.asyncio async def test_tg_results_exceptions_bad(): async with TaskGroup(retain=True) as t: task1 = await t.spawn(sleep, 1) await t.spawn(sleep, 2) await sleep(0.001) task1.cancel() with pytest.raises(CancelledError): t.results assert all(isinstance(e, CancelledError) for e in t.exceptions) @pytest.mark.asyncio async def test_tg_spawn(): t = TaskGroup() task = await t.spawn(sleep, 0.01) assert await t.next_done() == task assert await t.next_done() is None task = await t.spawn(sleep(0.01)) assert await t.next_done() == task @pytest.mark.asyncio async def test_tg_cancel_remaining(): tasks = [await spawn(sleep, secs, daemon=daemon) for secs, daemon in ((0.001, False), (0.2, True), (0.1, False), (0.1, False))] t = TaskGroup(tasks) assert await t.next_done() await t.cancel_remaining() assert not tasks[0].cancelled() # This is a daemon so is not cancelled assert not tasks[1].cancelled() assert tasks[2].cancelled() assert tasks[3].cancelled() assert not t.joined # join() cancels daemons await t.join() assert tasks[1].cancelled() assert t.joined @pytest.mark.asyncio async def test_tg_aiter(): tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)] t = TaskGroup(tasks) result = [task async for task in t] assert result == list(reversed(tasks)) @pytest.mark.asyncio async def test_tg_join_no_arg(): tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)] t = TaskGroup(tasks) await t.join() assert all(task.done() for task in tasks) assert not any(task.cancelled() for task in tasks) @pytest.mark.asyncio async def test_tg_cm_no_arg(): tasks = [await spawn(sleep, x) for x in (0.1, 0.01, -1)] async with TaskGroup(tasks) as t: pass assert all(task.done() for task in tasks) assert not any(task.cancelled() for task in tasks) assert t.completed is tasks[-1] @pytest.mark.asyncio async def test_tg_cm_all(): tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)] async with TaskGroup(tasks, wait=all) as t: pass assert all(task.done() for task in tasks) assert not any(task.cancelled() for task in tasks) assert t.completed is tasks[-1] @pytest.mark.asyncio async def test_tg_cm_none(): tasks = [await spawn(sleep, x/200) for x in range(1, 5)] async with TaskGroup(tasks, wait=None) as t: pass assert all(task.cancelled() for task in tasks) assert t.completed is None @pytest.mark.asyncio async def test_tg_cm_any(): tasks = [await spawn(sleep, x) for x in (0.1, 0.05, -1)] async with TaskGroup(tasks, wait=any) as t: pass assert all(task.done() for task in tasks) assert not tasks[-1].cancelled() assert all(task.cancelled() for task in tasks[:-1]) assert t.completed is tasks[-1] @pytest.mark.asyncio async def test_tg_join_object_1(): tasks = [await spawn(return_value(None, 0.01)), await spawn(return_value(3, 0.02))] t = TaskGroup(tasks, wait=object) await t.join() assert tasks[0].result() is None assert tasks[1].result() == 3 assert t.completed is tasks[1] assert t.result == 3 @pytest.mark.asyncio async def test_tg_join_object_2(): tasks = [await spawn(return_value(None, 0.01)), await spawn(return_value(4, 0.02)), await spawn(return_value(2, 2))] t = TaskGroup(tasks, wait=object) await t.join() assert 
t.completed is tasks[1] assert tasks[0].result() is None assert tasks[1].result() == 4 assert tasks[2].cancelled() @pytest.mark.asyncio async def test_tg_cm_object(): tasks = [await spawn(return_value(None, 0.01)), await spawn(return_value(3, 0.02))] async with TaskGroup(tasks, wait=object) as t: pass assert tasks[0].result() is None assert tasks[1].result() == 3 assert t.completed is tasks[1] tasks = [await spawn(return_value(None, 0.01)), await spawn(return_value(4, 0.02)), await spawn(return_value(2, 0.1))] async with TaskGroup(tasks, wait=object) as t: pass assert tasks[0].result() is None assert tasks[1].result() == 4 assert tasks[2].cancelled() assert t.completed is tasks[1] @pytest.mark.asyncio async def test_tg_join_errored(): for wait in (all, any, object): tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)] t = TaskGroup(tasks, wait=wait) bad_task = await t.spawn(my_raises(ArithmeticError)) await t.join() assert all(task.cancelled() for task in tasks) assert bad_task.done() and not bad_task.cancelled() assert t.completed is bad_task @pytest.mark.asyncio async def test_tg_cm_errored(): for wait in (all, any, object): tasks = [await spawn(sleep, x/200) for x in range(5, 0, -1)] async with TaskGroup(tasks, wait=wait) as t: bad_task = await t.spawn(my_raises(EOFError)) assert all(task.cancelled() for task in tasks) assert bad_task.done() and not bad_task.cancelled() assert t.completed is bad_task with pytest.raises(EOFError): t.result assert isinstance(t.exception, EOFError) @pytest.mark.asyncio async def test_tg_join_errored_past(): for wait in (all, any, object): tasks = [await spawn(my_raises, AttributeError) for n in range(3)] t = TaskGroup(tasks, wait=wait) tasks[1].cancel() await sleep(0.001) good_task = await t.spawn(return_value(3, 0.001)) await t.join() assert good_task.cancelled() assert t.completed is tasks[0] assert isinstance(t.exception, AttributeError) @pytest.mark.asyncio async def test_cm_join_errored_past(): for wait in (all, any, object): tasks = [await spawn(my_raises, BufferError) for n in range(3)] async with TaskGroup(tasks, wait=wait) as t: tasks[1].cancel() await sleep(0.001) good_task = await t.spawn(return_value(3, 0.001)) assert good_task.cancelled() assert t.completed is tasks[0] assert isinstance(t.exception, BufferError) @pytest.mark.asyncio async def test_cm_raises(): tasks = [await spawn(sleep, 0.01) for n in range(3)] with pytest.raises(ValueError): async with TaskGroup(tasks): raise ValueError assert all(task.cancelled() for task in tasks) @pytest.mark.asyncio async def test_cm_add_later(): tasks = [await spawn(sleep, 0) for n in range(3)] async with TaskGroup(tasks) as t: await sleep(0.001) await t.spawn(my_raises, LookupError) assert all(task.result() is None for task in tasks) assert t.completed in tasks assert t.result is None assert t.exception is None @pytest.mark.asyncio async def test_tg_multiple_groups(): task = await spawn(my_raises, FloatingPointError) TaskGroup([task]) with pytest.raises(RuntimeError): TaskGroup([task]) t3 = TaskGroup() with pytest.raises(RuntimeError): await t3.add_task(task) with pytest.raises(FloatingPointError): await task @pytest.mark.asyncio async def test_tg_joined(): task = await spawn(return_value(3)) for wait in (all, any, object): t = TaskGroup() assert not t.joined await t.join() assert t.joined with pytest.raises(RuntimeError): await t.spawn(my_raises, ImportError) with pytest.raises(RuntimeError): await t.add_task(task) await task @pytest.mark.asyncio async def test_tg_wait_bad(): tasks = [await 
spawn(sleep, x/200) for x in range(5, 0, -1)] with pytest.raises(ValueError): TaskGroup(tasks, wait=0) assert not any(task.cancelled() for task in tasks) for task in tasks: await task async def return_after_sleep(x, period=0.01): await sleep(period) return x @pytest.mark.asyncio async def test_timeout_after_coro_callstyles(): async def t1(*values): return 1 + sum(values) assert await timeout_after(0.01, t1) == 1 assert await timeout_after(0.01, t1()) == 1 assert await timeout_after(0.01, t1(2, 8)) == 11 assert await timeout_after(0.01, t1, 2, 8) == 11 coro = t1() with pytest.raises(ValueError): await timeout_after(0, coro, 1) await coro @pytest.mark.asyncio async def test_timeout_after_zero(): async def t1(*values): return 1 + sum(values) assert await timeout_after(0, t1) == 1 assert await timeout_after(0, t1, 2) == 3 assert await timeout_after(0, t1, 2, 8) == 11 @pytest.mark.asyncio async def test_timeout_after_no_expire(): async def t1(*values): return await return_after_sleep(1 + sum(values), 0.005) try: assert await timeout_after(0.1, t1, 1) == 2 except TaskTimeout: assert False assert True @pytest.mark.asyncio async def test_nested_after_no_expire_nested(): async def coro1(): pass async def child(): await timeout_after(0.001, coro1()) async def parent(): await timeout_after(0.003, child()) await parent() try: await sleep(0.005) except CancelledError: assert False @pytest.mark.asyncio async def test_nested_after_no_expire_nested2(): async def coro1(): pass async def child(): await timeout_after(0.001, coro1()) await sleep(0.005) async def parent(): try: await timeout_after(0.003, child()) except TaskTimeout: return assert False await parent() @pytest.mark.asyncio async def test_timeout_after_raises_IndexError(): try: await timeout_after(0.01, my_raises, IndexError) except IndexError: return assert False @pytest.mark.asyncio async def test_timeout_after_raises_CancelledError(): try: await timeout_after(0.01, my_raises, CancelledError) except CancelledError: return assert False @pytest.mark.asyncio async def test_nested_timeout(): results = [] async def coro1(): results.append('coro1 start') await sleep(1) results.append('coro1 done') async def coro2(): results.append('coro2 start') await sleep(1) results.append('coro2 done') # Parent should cause a timeout before the child. # Results in a TimeoutCancellationError instead of a normal TaskTimeout async def child(): try: await timeout_after(0.05, coro1()) results.append('coro1 success') except TaskTimeout: results.append('coro1 timeout') except TimeoutCancellationError: results.append('coro1 timeout cancel') await coro2() results.append('coro2 success') async def parent(): try: await timeout_after(0.01, child()) except TaskTimeout: results.append('parent timeout') await parent() assert results == [ 'coro1 start', 'coro1 timeout cancel', 'coro2 start', 'parent timeout' ] @pytest.mark.asyncio async def test_nested_context_timeout(): results = [] async def coro1(): results.append('coro1 start') await sleep(1) results.append('coro1 done') async def coro2(): results.append('coro2 start') await sleep(1) results.append('coro2 done') # Parent should cause a timeout before the child. 
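# (here the parent allows only 0.01s while the child's own window is 0.05s)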
# Results in a TimeoutCancellationError instead of a normal TaskTimeout async def child(): try: async with timeout_after(0.05) as ta: await coro1() results.append('coro1 success') except TaskTimeout: results.append('coro1 timeout') except TimeoutCancellationError: results.append('coro1 timeout cancel') assert not ta.expired await coro2() results.append('coro2 success') async def parent(): try: async with timeout_after(0.01) as ta: await child() except TaskTimeout: results.append('parent timeout') assert ta.expired await parent() assert results == [ 'coro1 start', 'coro1 timeout cancel', 'coro2 start', 'parent timeout' ] @pytest.mark.asyncio async def test_nested_context_timeout2(): async def coro1(): try: async with timeout_after(1) as ta: await sleep(5) except CancelledError as e: assert isinstance(e, TimeoutCancellationError) assert not ta.expired raise else: assert False async def coro2(): try: async with timeout_after(1.5) as ta: await coro1() except CancelledError as e: assert isinstance(e, TimeoutCancellationError) assert not ta.expired raise else: assert False async def parent(): try: async with timeout_after(0.01) as ta: await coro2() except (Exception, CancelledError) as e: assert isinstance(e, TaskTimeout) else: assert False assert ta.expired await parent() @pytest.mark.asyncio async def test_nested_context_timeout3(): async def coro1(): try: await timeout_after(1, sleep, 5) except CancelledError as e: assert isinstance(e, TimeoutCancellationError) raise else: assert False async def coro2(): try: await timeout_after(1.5, coro1) except CancelledError as e: assert isinstance(e, TimeoutCancellationError) raise else: assert False async def parent(): try: await timeout_after(0.001, coro2) except (Exception, CancelledError) as e: assert isinstance(e, TaskTimeout) else: assert False await parent() @pytest.mark.asyncio async def test_nested_timeout_again(): try: async with timeout_after(0.01): raise TaskTimeout(1.0) except TaskTimeout: pass @pytest.mark.asyncio async def test_nested_timeout_uncaught(): results = [] async def coro1(): results.append('coro1 start') await sleep(0.5) results.append('coro1 done') async def child(): # This will cause a TaskTimeout, but it's uncaught await timeout_after(0.001, coro1()) async def parent(): try: await timeout_after(1, child()) except TaskTimeout: results.append('parent timeout') except UncaughtTimeoutError: results.append('uncaught timeout') await parent() assert results == [ 'coro1 start', 'uncaught timeout' ] @pytest.mark.asyncio async def test_nested_context_timeout_uncaught(): results = [] async def coro1(): results.append('coro1 start') await sleep(0.5) results.append('coro1 done') async def child(): # This will cause a TaskTimeout, but it's uncaught async with timeout_after(0.001): await coro1() async def parent(): try: async with timeout_after(1): await child() except TaskTimeout: results.append('parent timeout') except UncaughtTimeoutError: results.append('uncaught timeout') await parent() assert results == [ 'coro1 start', 'uncaught timeout' ] @pytest.mark.asyncio async def test_nested_timeout_asyncio_wait_for(): async def foo(*, timeout=0.001): async with timeout_after(timeout): await sleep(10) # internal timeout 1 try: await asyncio.wait_for(foo(), None) assert False except TaskTimeout: pass except BaseException as e: assert False, e # internal timeout 2 try: await asyncio.wait_for(foo(), 2) assert False except TaskTimeout: pass except BaseException as e: assert False, e # external timeout try: await asyncio.wait_for(foo(timeout=2), 
0.001) assert False except asyncio.TimeoutError: pass except BaseException as e: assert False, e @pytest.mark.asyncio async def test_nested_timeout_asyncio_ensure_future(): async def foo(*, timeout=0.001): async with timeout_after(timeout): await sleep(10) try: await asyncio.ensure_future(foo()) assert False except TaskTimeout: pass except BaseException as e: assert False, e @pytest.mark.asyncio async def test_nested_timeout_asyncio_create_task(): async def foo(*, timeout=0.001): async with timeout_after(timeout): await sleep(10) try: await asyncio.create_task(foo()) assert False except TaskTimeout: pass except BaseException as e: assert False, e @pytest.mark.asyncio async def test_timeout_at_time(): async def t1(*values): return 1 + sum(values) loop = get_event_loop() assert await timeout_at(loop.time(), t1) == 1 assert await timeout_at(loop.time(), t1, 2, 8) == 11 @pytest.mark.asyncio async def test_timeout_at_expires(): async def slow(): await sleep(0.02) return 2 loop = get_event_loop() try: await timeout_at(loop.time() + 0.001, slow) except TaskTimeout: return assert False @pytest.mark.asyncio async def test_timeout_at_context(): loop = get_event_loop() try: async with timeout_at(loop.time() + 0.001): await sleep(0.02) except TaskTimeout: return assert False # Ignore @pytest.mark.asyncio async def test_ignore_after_coro_callstyles(): async def t1(*values): return 1 + sum(values) assert await ignore_after(0.001, t1) == 1 assert await ignore_after(0.001, t1()) == 1 assert await ignore_after(0.001, t1(2, 8)) == 11 assert await ignore_after(0.001, t1, 2, 8) == 11 @pytest.mark.asyncio async def test_ignore_after_timeout_result(): async def t1(*values): await sleep(0.01) return 1 + sum(values) assert await ignore_after(0.005, t1, timeout_result=100) == 100 assert await ignore_after(0.005, t1, timeout_result=all) is all @pytest.mark.asyncio async def test_ignore_after_zero(): async def t1(*values): return 1 + sum(values) assert await ignore_after(0, t1) == 1 assert await ignore_after(0, t1, 2) == 3 assert await ignore_after(0, t1, 2, 8) == 11 @pytest.mark.asyncio async def test_ignore_after_no_expire(): async def t1(*values): return await return_after_sleep(1 + sum(values), 0.001) assert await ignore_after(0.1, t1, 1) == 2 await sleep(0.002) @pytest.mark.asyncio async def test_ignore_after_no_expire_nested(): async def coro1(): return 2 async def child(): return await ignore_after(0.001, coro1()) async def parent(): return await ignore_after(0.003, child()) try: result = await parent() await sleep(0.005) except Exception: assert False else: assert result == 2 @pytest.mark.asyncio async def test_ignore_after_no_expire_nested2(): async def coro1(): return 5 async def child(): result = await ignore_after(0.001, coro1(), timeout_result=1) await sleep(0.005) return result async def parent(): try: result = await ignore_after(0.003, child()) except Exception: assert False assert result is None await parent() @pytest.mark.asyncio async def test_ignore_after_raises_KeyError(): try: await ignore_after(0.01, my_raises, KeyError) except KeyError: return assert False @pytest.mark.asyncio async def test_ignore_after_raises_CancelledError(): try: await ignore_after(0.01, my_raises, CancelledError) except CancelledError: return assert False @pytest.mark.asyncio async def test_nested_ignore(): results = [] async def coro1(): results.append('coro1 start') await sleep(1) results.append('coro1 done') async def coro2(): results.append('coro2 start') await sleep(1) results.append('coro2 done') # Parent should 
cause an ignore before the child. # Results in a TimeoutCancellationError instead of a normal TaskTimeout async def child(): try: await ignore_after(0.005, coro1()) results.append('coro1 success') except TaskTimeout: results.append('coro1 timeout') except TimeoutCancellationError: results.append('coro1 timeout cancel') await coro2() results.append('coro2 success') async def parent(): try: await ignore_after(0.001, child()) results.append('parent success') except TaskTimeout: results.append('parent timeout') await parent() assert results == [ 'coro1 start', 'coro1 timeout cancel', 'coro2 start', 'parent success' ] @pytest.mark.asyncio async def test_nested_ignore_context_timeout(): results = [] async def coro1(): results.append('coro1 start') await sleep(1) results.append('coro1 done') async def coro2(): results.append('coro2 start') await sleep(1) results.append('coro2 done') # Parent should cause a timeout before the child. # Results in a TimeoutCancellationError instead of the ignore expiring normally async def child(): try: async with ignore_after(0.005): await coro1() results.append('coro1 success') except TaskTimeout: results.append('coro1 timeout') except TimeoutCancellationError: results.append('coro1 timeout cancel') await coro2() results.append('coro2 success') async def parent(): try: async with ignore_after(0.001): await child() results.append('parent success') except TaskTimeout: results.append('parent timeout') await parent() assert results == [ 'coro1 start', 'coro1 timeout cancel', 'coro2 start', 'parent success' ] @pytest.mark.asyncio async def test_nested_ignore_context_timeout2(): async def coro1(): try: async with ignore_after(1): await sleep(5) assert False except CancelledError as e: assert isinstance(e, TimeoutCancellationError) raise async def coro2(): try: async with ignore_after(1.5): await coro1() assert False except CancelledError as e: assert isinstance(e, TimeoutCancellationError) raise async def parent(): try: async with ignore_after(0.001): await coro2() except Exception: assert False await parent() @pytest.mark.asyncio async def test_nested_ignore_context_timeout3(): async def coro1(): try: await ignore_after(1, sleep, 5) except CancelledError as e: assert isinstance(e, TimeoutCancellationError) raise else: assert False async def coro2(): try: await ignore_after(1.5, coro1) return 3 except CancelledError as e: assert isinstance(e, TimeoutCancellationError) raise else: assert False async def parent(): try: result = await ignore_after(0.001, coro2) except Exception: assert False else: assert result is None await parent() @pytest.mark.asyncio async def test_nested_ignore_timeout_uncaught(): results = [] async def coro1(): results.append('coro1 start') await sleep(0.5) results.append('coro1 done') async def child(): # This will do nothing await ignore_after(0.01, coro1()) results.append('coro1 ignored') return 1 async def parent(): try: if await ignore_after(0.02, child()) is None: results.append('child ignored') else: results.append('child succeeded') except TaskTimeout: results.append('parent timeout') except UncaughtTimeoutError: results.append('uncaught timeout') await parent() assert results == [ 'coro1 start', 'coro1 ignored', 'child succeeded' ] @pytest.mark.asyncio async def test_nested_ignore_context_timeout_uncaught(): results = [] async def coro1(): results.append('coro1 start') await sleep(0.05) results.append('coro1 done') async def child(): # This will be ignored async with ignore_after(0.001): await coro1() results.append('child succeeded') async def parent():
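# The parent's 0.1s ignore window comfortably outlives the child's ignored 0.001s one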
try: async with ignore_after(0.1): await child() results.append('parent succeeded') except TaskTimeout: results.append('parent timeout') except UncaughtTimeoutError: results.append('uncaught timeout') await parent() assert results == [ 'coro1 start', 'child succeeded', 'parent succeeded' ] @pytest.mark.asyncio async def test_ignore_at_time(): async def t1(*values): return 1 + sum(values) loop = get_event_loop() assert await ignore_at(loop.time(), t1) == 1 assert await ignore_at(loop.time(), t1, 2, 8) == 11 @pytest.mark.asyncio async def test_ignore_at_expires(): async def slow(): await sleep(0.02) return 2 loop = get_event_loop() try: result = await ignore_at(loop.time() + 0.001, slow()) except Exception: assert False assert result is None try: result = await ignore_at(loop.time() + 0.001, slow, timeout_result=1) except Exception: assert False assert result == 1 @pytest.mark.asyncio async def test_ignore_at_context(): loop = get_event_loop() try: async with ignore_at(loop.time() + 0.001): await sleep(0.02) assert False except Exception: assert False # # Task group tests snitched from curio # @pytest.mark.asyncio async def test_task_group(): async def child(x, y): return x + y async def main(): async with TaskGroup() as g: t1 = await g.spawn(child, 1, 1) t2 = await g.spawn(child, 2, 2) t3 = await g.spawn(child, 3, 3) assert t1.result() == 2 assert t2.result() == 4 assert t3.result() == 6 await main() @pytest.mark.asyncio async def test_task_group_existing(): evt = Event() async def child(x, y): return x + y async def child2(x, y): await evt.wait() return x + y async def main(): t1 = await spawn(child, 1, 1) t2 = await spawn(child2, 2, 2) t3 = await spawn(child2, 3, 3) t4 = await spawn(child, 4, 4) await t1 await t4 async with TaskGroup([t1, t2, t3]) as g: evt.set() await g.add_task(t4) assert t1.result() == 2 assert t2.result() == 4 assert t3.result() == 6 assert t4.result() == 8 await main() @pytest.mark.asyncio async def test_task_any_cancel(): evt = Event() async def child(x, y): return x + y async def child2(x, y): await evt.wait() return x + y async def main(): async with TaskGroup(wait=any) as g: t1 = await g.spawn(child, 1, 1) t2 = await g.spawn(child2, 2, 2) t3 = await g.spawn(child2, 3, 3) assert t1.result() == 2 assert t1 == g.completed assert g.result == 2 assert g.exception is None assert t2.cancelled() assert t3.cancelled() await main() @pytest.mark.asyncio async def test_task_any_error(): evt = Event() async def child(x, y): return x + y async def child2(x, y): await evt.wait() return x + y async def main(): async with TaskGroup(wait=any) as g: t1 = await g.spawn(child, 1, '1') t2 = await g.spawn(child2, 2, 2) t3 = await g.spawn(child2, 3, 3) assert isinstance(t1.exception(), TypeError) assert g.completed is t1 with pytest.raises(TypeError): g.result assert g.exception is t1.exception() assert t2.cancelled() assert t3.cancelled() await main() @pytest.mark.asyncio async def test_task_group_iter(): async def child(x, y): return x + y async def main(): results = set() async with TaskGroup() as g: await g.spawn(child, 1, 1) await g.spawn(child, 2, 2) await g.spawn(child, 3, 3) async for task in g: results.add(task.result()) assert results == {2, 4, 6} await main() @pytest.mark.asyncio async def test_task_group_error(): evt = Event() async def child(x, y): x + y await evt.wait() async def main(): async with TaskGroup() as g: t1 = await g.spawn(child, 1, 1) t2 = await g.spawn(child, 2, 2) t3 = await g.spawn(child, 3, 'bad') assert g.completed is t3 assert g.exception == 
t3.exception() assert t1.cancelled() assert t2.cancelled() await main() @pytest.mark.asyncio async def test_task_group_error_block(): evt = Event() async def child(x, y): await evt.wait() async def main(): try: async with TaskGroup() as g: t1 = await g.spawn(child, 1, 1) t2 = await g.spawn(child, 2, 2) t3 = await g.spawn(child, 3, 3) raise RuntimeError() except RuntimeError: assert True else: assert False assert t1.cancelled() assert t2.cancelled() assert t3.cancelled() await main() @pytest.mark.asyncio async def test_task_group_multierror(): evt = Event() async def child(exctype): if exctype: raise exctype('Died') await evt.wait() async def main(): async with TaskGroup() as g: t1 = await g.spawn(child, RuntimeError) t2 = await g.spawn(child, MemoryError) await g.spawn(child, None) await sleep(0) evt.set() assert isinstance(t1.exception(), RuntimeError) assert isinstance(t2.exception(), MemoryError) await main() @pytest.mark.asyncio async def test_task_group_cancel(): evt = Event() evt2 = Event() async def child(): try: await evt.wait() except CancelledError: assert True raise else: assert False async def coro(): try: async with TaskGroup() as g: t1 = await g.spawn(child) t2 = await g.spawn(child) t3 = await g.spawn(child) evt2.set() except CancelledError: assert t1.cancelled() assert t2.cancelled() assert t3.cancelled() raise else: assert False async def main(): t = await spawn(coro) await evt2.wait() t.cancel() try: await t except CancelledError: pass await main() @pytest.mark.asyncio async def test_task_group_timeout(): evt = Event() async def child(): try: await evt.wait() except CancelledError: assert True raise else: assert False async def coro(): try: async with timeout_after(0.01): try: async with TaskGroup() as g: t1 = await g.spawn(child) t2 = await g.spawn(child) t3 = await g.spawn(child) except CancelledError: assert t1.cancelled() assert t2.cancelled() assert t3.cancelled() raise except TaskTimeout: assert True else: assert False await coro() @pytest.mark.asyncio async def test_task_group_cancel_remaining(): evt = Event() async def child(x, y): return x + y async def waiter(): await evt.wait() async def main(): async with TaskGroup() as g: t1 = await g.spawn(child, 1, 1) t2 = await g.spawn(waiter) t3 = await g.spawn(waiter) t = await g.next_done() assert t == t1 await g.cancel_remaining() assert t2.cancelled() assert t3.cancelled() await main() @pytest.mark.asyncio async def test_task_group_cancel_remaining_waits(): async def sleep_soundly(): try: await sleep(0.01) except CancelledError: await sleep(0.01) task = await spawn(sleep_soundly) with pytest.raises(CancelledError): async with TaskGroup([task]): await sleep(0) # ensure the tasks are scheduled raise CancelledError # Exiting the context with an exception (here, CancelledError) waits for non-daemonic tasks # to finish assert task.done() @pytest.mark.asyncio async def test_task_group_cancel_remaining_daemonic_waits(): async def sleep_soundly(): try: await sleep(0.01) except CancelledError: await sleep(0.01) task = await spawn(sleep_soundly, daemon=True) with pytest.raises(CancelledError): async with TaskGroup([task]): await sleep(0) # ensure the tasks are scheduled raise CancelledError # The task is daemonic but is still waited for.
assert task.done() assert not task.cancelled() # Didn't raise CancelledError @pytest.mark.asyncio async def test_task_group_use_error(): async def main(): async with TaskGroup() as g: t1 = await g.spawn(sleep, 0) with pytest.raises(RuntimeError): await g.add_task(t1) with pytest.raises(RuntimeError): await g.spawn(sleep, 0) t2 = await spawn(sleep, 0) with pytest.raises(RuntimeError): await g.add_task(t2) await t2 await main() @pytest.mark.asyncio async def test_task_group_cancel_task(): for wait in (all, object, any): async with TaskGroup(wait=object) as g: task1 = await g.spawn(sleep, 1) task2 = await g.spawn(sleep, 2) await sleep(0.001) task1.cancel() assert task1.cancelled() assert task2.cancelled() assert g.completed is task1 assert isinstance(g.exception, CancelledError) with pytest.raises(CancelledError): g.result() @pytest.mark.asyncio async def test_task_group_cancel_task2(): async with TaskGroup(wait=None) as g: task1 = await g.spawn(sleep, 1) task2 = await g.spawn(sleep, 2) assert task1.cancelled() assert task2.cancelled() assert g.completed is None assert g.exception is None with pytest.raises(RuntimeError): assert g.result is None @pytest.mark.asyncio async def test_task_group_bad_result_exception(): async with TaskGroup(wait=None) as g: task1 = await g.spawn(sleep, 1) await sleep(0.001) with pytest.raises(RuntimeError): g.result with pytest.raises(RuntimeError): g.exception with pytest.raises(RuntimeError): g.results with pytest.raises(RuntimeError): g.exceptions task1.cancel() @pytest.mark.asyncio async def test_daemon_tasks_not_waited_for_and_cancelled(): evt = Event() async def wait_forever(): await evt.wait() async with TaskGroup() as g: d = await g.spawn(wait_forever, daemon=True) t = await g.spawn(return_value, 5, 0.005) assert g.tasks == {t} assert g.daemons == {d} assert d.cancelled() assert g.result == 5 assert g.exception is None @pytest.mark.asyncio async def test_daemon_task_errors_ignored(): async with TaskGroup() as g: d = await g.spawn(my_raises(ArithmeticError), daemon=True) t = await g.spawn(return_value, 5, 0.005) assert g.tasks == {t} assert g.daemons == {d} await sleep(0.01) assert g.result == 5 assert g.exception is None # See https://github.com/kyuupichan/aiorpcX/issues/37 @pytest.mark.asyncio async def test_cancel_remaining_on_group_with_stubborn_task(): evt = Event() async def run_forever(): while True: try: await evt.wait() break except CancelledError: pass async def run_group(): async with group: await group.spawn(run_forever) from asyncio import create_task group = TaskGroup() create_task(run_group()) await sleep(0.01) try: async with timeout_after(0.01): await group.cancel_remaining() except TaskTimeout: pass # Clean teardown evt.set() await sleep(0.001) # See https://github.com/kyuupichan/aiorpcX/issues/46 @pytest.mark.asyncio async def test_tasks_pop(): delay = 0.05 N = 10 async def finish_quick(): await sleep(delay / 2) async with TaskGroup() as group: await group.spawn(finish_quick) assert len(group.tasks) await sleep(delay) assert not len(group.tasks) async with TaskGroup() as group: for n in range(N): await group.spawn(finish_quick) await group.spawn(finish_quick, daemon=True) assert len(group.tasks) == N assert not len(group.tasks) task1 = await spawn(finish_quick) task2 = await spawn(finish_quick, daemon=True) async with TaskGroup((task1, task2)) as group: assert len(group.tasks) == 1 await sleep(delay) assert not len(group.tasks) def test_TaskTimeout_str(): t = TaskTimeout(0.5) assert str(t) == 'task timed out after 0.5s' 
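# Editorial sketch (illustrative only, not executed): the nesting semantics the
# tests above exercise, in one place. `fetch_slowly` is a hypothetical coroutine
# standing in for any slow call.
#
#     async def fetch_with_budget():
#         async with timeout_after(5):
#             # If the outer deadline fires while the inner ignore_after is still
#             # pending, the inner call is cancelled with TimeoutCancellationError
#             # instead of swallowing its own timeout as a normal ignore.
#             return await ignore_after(1, fetch_slowly())  # None if the inner ignore expired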
aiorpcX-0.24/tests/test_framing.py000077500000000000000000000150011474217261100172470ustar00rootroot00000000000000import os import random import pytest from aiorpcx import ( BinaryFramer, BitcoinFramer, OversizedPayloadError, BadMagicError, BadChecksumError, FramerBase, NewlineFramer, TaskGroup, sleep, timeout_after, ) from aiorpcx.framing import ByteQueue @pytest.mark.asyncio async def test_FramerBase(): framer = FramerBase() with pytest.raises(NotImplementedError): framer.received_bytes(b'') with pytest.raises(NotImplementedError): await framer.receive_message() with pytest.raises(NotImplementedError): framer.frame(b'') with pytest.raises(NotImplementedError): framer.frame(TypeError) def test_NewlineFramer_framing(): framer = NewlineFramer() assert framer.frame(b'foo') == b'foo\n' @pytest.mark.asyncio async def test_NewlineFramer_messages(): framer = NewlineFramer() framer.received_bytes(b'abc\ndef\ngh') assert await framer.receive_message() == b'abc' assert await framer.receive_message() == b'def' async def receive_message(): return await framer.receive_message() async def put_rest(): await sleep(0.001) framer.received_bytes(b'i\n') async with TaskGroup() as group: task = await group.spawn(receive_message) await group.spawn(put_rest) assert task.result() == b'ghi' @pytest.mark.asyncio async def test_NewlineFramer_overflow(): framer = NewlineFramer(max_size=5) framer.received_bytes(b'abcde\n') assert await framer.receive_message() == b'abcde' framer.received_bytes(b'abcde') framer.received_bytes(b'f') with pytest.raises(MemoryError): await framer.receive_message() # Resynchronizes to next \n, returns b'yz' and stores 'AB' framer.received_bytes(b'ghijklmnopqrstuvwx\nyz\nAB') assert await framer.receive_message() == b'yz' # Add 'C' framer.received_bytes(b'C') async with TaskGroup() as group: task = await group.spawn(framer.receive_message()) await sleep(0.001) framer.received_bytes(b'DEFGHIJKL\nYZ') # Accepts over-sized message as doesn't need to store it assert task.result() == b'ABCDEFGHIJKL' framer.received_bytes(b'\n') assert await framer.receive_message() == b'YZ' framer = NewlineFramer(max_size=0) framer.received_bytes(b'abc') async with TaskGroup() as group: task = await group.spawn(framer.receive_message()) await sleep(0.001) framer.received_bytes(b'\n') assert task.result() == b'abc' @pytest.mark.asyncio async def test_ByteQueue(): bq = ByteQueue() lengths = [random.randrange(0, 15) for n in range(40)] data = os.urandom(sum(lengths)) answer = [] cursor = 0 for length in lengths: answer.append(data[cursor: cursor + length]) cursor += length assert b''.join(answer) == data async def putter(): cursor = 0 while cursor < len(data): size = random.randrange(0, min(15, len(data) - cursor + 1)) bq.put_nowait(data[cursor:cursor + size]) cursor += size await sleep(random.random() * 0.005) async def getter(): result = [] for length in lengths: item = await bq.receive(length) result.append(item) await sleep(random.random() * 0.005) return result async with timeout_after(1): async with TaskGroup() as group: await group.spawn(putter) gjob = await group.spawn(getter) assert gjob.result() == answer assert bq.parts == [b''] assert bq.parts_len == 0 class TestBitcoinFramer(): def test_framing(self): framer = BitcoinFramer() result = framer.frame((b'version', b'payload')) assert result == b'\xe3\xe1\xf3\xe8version\x00\x00\x00\x00\x00' \ b'\x07\x00\x00\x00\xe7\x871\xbbpayload' @pytest.mark.asyncio async def test_not_implemented(self): framer = BinaryFramer() with pytest.raises(NotImplementedError): 
framer._checksum(b'') with pytest.raises(NotImplementedError): framer._build_header(b'', b'') with pytest.raises(NotImplementedError): await framer._receive_header() def test_oversized_command(self): framer = BitcoinFramer() with pytest.raises(ValueError): framer._build_header(bytes(13), b'') @pytest.mark.asyncio async def test_oversized_message(self): framer = BitcoinFramer() framer.max_payload_size = 2000 framer._max_block_size = 10000 header = framer._build_header(b'', bytes(framer.max_payload_size)) framer.received_bytes(header) await framer._receive_header() header = framer._build_header(b'', bytes(framer.max_payload_size + 1)) framer.received_bytes(header) with pytest.raises(OversizedPayloadError): await framer._receive_header() header = framer._build_header(b'block', bytes(framer._max_block_size)) framer.received_bytes(header) await framer._receive_header() header = framer._build_header(b'block', bytes(framer._max_block_size + 1)) framer.received_bytes(header) with pytest.raises(OversizedPayloadError): await framer._receive_header() @pytest.mark.asyncio async def test_receive_message(self): framer = BitcoinFramer() result = framer.frame((b'version', b'payload')) framer.received_bytes(result) command, payload = await framer.receive_message() assert command == b'version' assert payload == b'payload' @pytest.mark.asyncio async def test_bad_magic(self): framer = BitcoinFramer() good_msg = framer.frame((b'version', b'payload')) pos = random.randrange(0, 24) for n in range(4): msg = bytearray(good_msg) msg[n] ^= 1 framer.received_bytes(msg[:pos]) # Just header should trigger the error framer.received_bytes(msg[pos:24]) with pytest.raises(BadMagicError): await framer.receive_message() @pytest.mark.asyncio async def test_bad_checksum(self): framer = BitcoinFramer() good_msg = framer.frame((b'version', b'payload')) pos = random.randrange(0, len(good_msg)) for n in range(20, 24): msg = bytearray(good_msg) msg[n] ^= 1 framer.received_bytes(msg[:pos]) framer.received_bytes(msg[pos:]) with pytest.raises(BadChecksumError): await framer.receive_message() aiorpcX-0.24/tests/test_jsonrpc.py000077500000000000000000001172531474217261100173160ustar00rootroot00000000000000import sys from itertools import combinations, count import json import pytest from aiorpcx import ( Request, handler_invocation, Queue, ignore_after, ProtocolError, TaskGroup, Batch, Notification, JSONRPCConnection, JSONRPC, JSONRPCv1, JSONRPCv2, JSONRPCLoose, RPCError, JSONRPCAutoDetect, timeout_after, Event, ) from aiorpcx.jsonrpc import Response, CodeMessageError from util import assert_RPCError, assert_ProtocolError from random import shuffle def assert_is_error_response(item, text, code): assert isinstance(item, Response) item = item.result assert isinstance(item, RPCError) assert item.code == code assert text in item.message def assert_is_request(item, method, args): assert isinstance(item, Request) assert item.method == method assert item.args == args def assert_is_notification(item, method, args): assert isinstance(item, Notification) assert item.method == method assert item.args == args def assert_is_good_response(item, result): assert isinstance(item, Response) assert item.result == result def canonical_message(protocol, payload): payload = payload.copy() if protocol == JSONRPCv2: payload['jsonrpc'] = '2.0' elif protocol == JSONRPCv1: if 'method' in payload and 'params' not in payload: payload['params'] = [] if 'error' in payload and 'result' not in payload: payload['result'] = None if 'result' in payload and 'error' not in 
payload: payload['error'] = None return json.dumps(payload).encode() def payload_to_item(protocol, payload): return protocol.message_to_item(canonical_message(protocol, payload)) @pytest.fixture(params=(JSONRPCv1, JSONRPCv2, JSONRPCLoose, JSONRPCAutoDetect)) def protocol(request): return request.param @pytest.fixture(params=(JSONRPCv1, JSONRPCv2, JSONRPCLoose)) def protocol_no_auto(request): return request.param @pytest.fixture(params=(JSONRPCLoose, JSONRPCv2)) def batch_protocol(request): return request.param # MISC def test_abstract(): class MyProtocol(JSONRPC): pass with pytest.raises(NotImplementedError): MyProtocol._message_id({}, True) with pytest.raises(NotImplementedError): MyProtocol._request_args({}) def test_exception_is_hashable(): hash(CodeMessageError(0, '')) # see if raises # ENCODING def test_parse_errors(protocol_no_auto): protocol = protocol_no_auto # Bad encoding message = b'123\xff' with pytest.raises(ProtocolError) as e: protocol.message_to_item(message) assert e.value.code == JSONRPC.PARSE_ERROR assert 'messages must be encoded in UTF-8' in e.value.message assert b'"id":null' in e.value.error_message # Bad JSON message = b'{"foo",}' with pytest.raises(ProtocolError) as e: protocol.message_to_item(message) assert e.value.code == JSONRPC.PARSE_ERROR assert 'invalid JSON' in e.value.message assert b'"id":null' in e.value.error_message messages = [b'2', b'"foo"', b'2.78'] for message in messages: with pytest.raises(ProtocolError) as e: protocol.message_to_item(message) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'must be a dictionary' in e.value.message assert b'"id":null' in e.value.error_message # Requests def test_request(): for bad_method in (None, 2, b'', [2], {}): with pytest.raises(ProtocolError) as e: Request(bad_method, []) assert e.value.code == JSONRPC.METHOD_NOT_FOUND assert 'must be a string' in e.value.message with pytest.raises(ProtocolError) as e: Notification(bad_method, []) assert e.value.code == JSONRPC.METHOD_NOT_FOUND assert 'must be a string' in e.value.message for bad_args in (2, "foo", None, False): with pytest.raises(ProtocolError) as e: Request('method', bad_args) assert e.value.code == JSONRPC.INVALID_ARGS assert 'arguments' in e.value.message with pytest.raises(ProtocolError) as e: Notification('', bad_args) assert e.value.code == JSONRPC.INVALID_ARGS assert 'arguments' in e.value.message assert repr(Request('m', [2])) == "Request('m', [2])" assert repr(Request('m', [])) == "Request('m', [])" assert repr(Request('m', {})) == "Request('m', {})" assert repr(Request('m', {"a": 0})) == "Request('m', {'a': 0})" def test_Batch(): b = Batch([Request("m", []), Request("n", [])]) assert repr(b) == "Batch(2 items)" with pytest.raises(ProtocolError) as e: Batch([Request('m', []), Response(2)]) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'homogeneous' in e.value.message with pytest.raises(ProtocolError) as e: Batch([b]) assert e.value.code == JSONRPC.INVALID_REQUEST with pytest.raises(ProtocolError) as e: Batch(2) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'must be a list' in e.value.message with pytest.raises(ProtocolError) as e: Batch((x for x in (1, ))) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'must be a list' in e.value.message assert b[:2] == b.items[:2] def test_JSONRPCv1_ill_formed(): protocol = JSONRPCv1 # Named arguments payloads = [ {"method": "a", "params": {}, "id": 123}, {"method": "a", "params": {"a": 1, "b": "c"}, "id": 123}, ] for payload in payloads: message = canonical_message(protocol, payload) 
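# JSON-RPC 1.0 defines "params" as a positional array only, so named (dict) arguments must be rejected outright.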
with pytest.raises(ProtocolError) as e: protocol.message_to_item(message) assert e.value.code == JSONRPC.INVALID_ARGS assert 'invalid request arguments' in e.value.message assert b'123' in e.value.error_message request = Request('a', {"a": 1}) with pytest.raises(ProtocolError) as e: protocol.request_message(request, 0) assert e.value.code == JSONRPC.INVALID_ARGS assert 'named arguments' in e.value.message # Requires an ID payload = {"method": "a", "params": [1, "foo"]} message = canonical_message(protocol, payload) with pytest.raises(ProtocolError) as e: protocol.message_to_item(message) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'no "id"' in e.value.message assert b'"id":null' in e.value.error_message def test_bad_requests(protocol_no_auto): protocol = protocol_no_auto payload = {"method": 2, "params": 3, "id": 0} with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_ARGS assert 'invalid request arguments' in e.value.message assert b'"id":0' in e.value.error_message def test_good_requests(protocol_no_auto): protocol = protocol_no_auto payload = {"method": "", "id": -1} item, request_id = payload_to_item(protocol, payload) assert request_id == -1 assert_is_request(item, '', []) # recommended against in the spec, but valid payload = {"method": "", "id": None} item, request_id = payload_to_item(protocol, payload) assert request_id is None assert_is_notification(item, '', []) # recommended against in the spec, but valid payload = {"method": "", "id": 2.5} item, request_id = payload_to_item(protocol, payload) assert request_id == 2.5 assert_is_request(item, '', []) payload = {"method": "a", "id": 0} item, request_id = payload_to_item(protocol, payload) assert request_id == 0 assert_is_request(item, 'a', []) payload = {"method": "a", "params": [], "id": ""} item, request_id = payload_to_item(protocol, payload) assert request_id == "" assert_is_request(item, 'a', []) # Rest do not apply to JSONRPCv1; tested to fail elsewhere if protocol == JSONRPCv1: return payload = {"method": "a", "params": [1, "foo"]} item, request_id = payload_to_item(protocol, payload) assert_is_notification(item, 'a', [1, "foo"]) payload = {"method": "a", "params": {}} item, request_id = payload_to_item(protocol, payload) assert_is_notification(item, 'a', {}) payload = {"method": "a", "params": {"a": 1, "b": "c"}} item, request_id = payload_to_item(protocol, payload) assert_is_notification(item, 'a', {"a": 1, "b": "c"}) payload = {"method": "a", "params": {}, "id": 1} item, request_id = payload_to_item(protocol, payload) assert request_id == 1 assert_is_request(item, 'a', {}) payload = {"method": "a", "params": {"a": 1, "b": "c"}, "id": 1} item, request_id = payload_to_item(protocol, payload) assert request_id == 1 assert_is_request(item, 'a', {"a": 1, "b": "c"}) # # RESPONSES def test_response_bad(protocol_no_auto): protocol = protocol_no_auto # Missing ID payload = {"result": 2} with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'no "id"' in e.value.message assert not e.value.error_message assert e.value.response_msg_id is None payload = {"error": {"code": 2, "message": "try harder"}} with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'no "id"' in e.value.message assert not e.value.error_message assert e.value.response_msg_id is None # Result and error if protocol != JSONRPCv1: payload = {"result": 
0, "error": {"code": 2, "message": ""}, "id": 0} with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'both "result" and' in e.value.message assert e.value.response_msg_id == 0 assert not e.value.error_message payload = {"result": 1, "error": None, "id": 0} if protocol == JSONRPCLoose: payload_to_item(protocol, payload) else: with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'both "result" and' in e.value.message assert e.value.response_msg_id == 0 assert not e.value.error_message # No result, also no error payload = {"foo": 1, "id": 1} with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'neither "result" nor' in e.value.message assert e.value.response_msg_id == 1 assert not e.value.error_message # Bad ID payload = {"result": 2, "id": []} with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'invalid "id"' in e.value.message assert e.value.response_msg_id is None assert not e.value.error_message def test_response_good(protocol_no_auto): protocol = protocol_no_auto # Integer payload = {"result": 2, "id": 1} item, request_id = payload_to_item(protocol, payload) assert request_id == 1 assert_is_good_response(item, 2) # Float payload = {"result": 2.1, "id": 1} item, request_id = payload_to_item(protocol, payload) assert_is_good_response(item, 2.1) # String payload = {"result": "f", "id": 1} item, request_id = payload_to_item(protocol, payload) assert_is_good_response(item, "f") # None payload = {"result": None, "id": 1} item, request_id = payload_to_item(protocol, payload) assert request_id == 1 assert_is_good_response(item, None) # Array payload = {"result": [1, 2], "id": 1} item, request_id = payload_to_item(protocol, payload) assert_is_good_response(item, [1, 2]) # Dictionary payload = {"result": {"a": 1}, "id": 1} item, request_id = payload_to_item(protocol, payload) assert_is_good_response(item, {"a": 1}) # Additional junk payload = {"result": 2, "id": 1, "junk": 0} item, request_id = payload_to_item(protocol, payload) assert_is_good_response(item, 2) def test_JSONRPCv2_response_error_bad(): payloads = [ {"error": 2, "id": 1}, {"error": "bar", "id": 1}, {"error": {"code": 1}, "id": 1}, {"error": {"message": "foo"}, "id": 1}, {"error": {"code": None, "message": "m"}, "id": 1}, {"error": {"code": 1, "message": None}, "id": 1}, {"error": {"code": "s", "message": "error"}, "id": 1}, {"error": {"code": 2, "message": 2}, "id": 1}, {"error": {"code": 2.5, "message": "bar"}, "id": 1}, ] protocol = JSONRPCv2 for payload in payloads: with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'ill-formed' in e.value.message assert e.value.response_msg_id == 1 assert not e.value.error_message def test_JSONRPCLoose_responses(): protocol = JSONRPCLoose payload = {"result": 0, "error": None, "id": 1} item, request_id = payload_to_item(protocol, payload) assert request_id == 1 assert_is_good_response(item, 0) payload = {"result": None, "error": None, "id": 1} item, request_id = payload_to_item(protocol, payload) assert_is_good_response(item, None) payload = {"result": None, "error": 2, "id": 1} item, request_id = payload_to_item(protocol, payload) assert_is_error_response(item, 'no error message', 2) payload = {"result": 4, "error": 
2, "id": 1} with pytest.raises(ProtocolError) as e: payload_to_item(protocol, payload) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'both' in e.value.message assert e.value.response_msg_id == 1 assert not e.value.error_message def test_JSONRPCv2_required_jsonrpc(): protocol = JSONRPCv2 payloads = [ {"error": {"code": 2, "message": "bar"}, "id": 1}, {"result": 1, "id": 2}, ] for payload in payloads: with pytest.raises(ProtocolError) as e: message = json.dumps(payload).encode() protocol.message_to_item(message) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'jsonrpc' in e.value.message assert not e.value.error_message payload = {"method": "f"} with pytest.raises(ProtocolError) as e: message = json.dumps(payload).encode() protocol.message_to_item(message) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'jsonrpc' in e.value.message # Respond to ill-formed "notification" assert b'"id":null' in e.value.error_message payload = {"method": "f", "id": 0} with pytest.raises(ProtocolError) as e: message = json.dumps(payload).encode() protocol.message_to_item(message) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'jsonrpc' in e.value.message assert b'jsonrpc' in e.value.error_message assert b'"id":0' in e.value.error_message def test_JSONRPCv1_errors(): protocol = JSONRPCv1 payloads = [ {"error": 2, "id": 1}, {"error": "bar", "id": 1}, {"error": {"code": 1}, "id": 1}, {"error": {"message": "foo"}, "id": 1}, {"error": {"code": None, "message": "m"}, "id": 1}, {"error": {"code": 1, "message": None}, "id": 1}, {"error": {"code": "s", "message": "error"}, "id": 1}, {"error": {"code": 2, "message": 2}, "id": 1}, {"error": {"code": 2.5, "message": "bar"}, "id": 1}, ] for payload in payloads: item, request_id = payload_to_item(protocol, payload) code = protocol.ERROR_CODE_UNAVAILABLE error = payload['error'] message = 'no error message provided' if isinstance(error, str): message = error elif isinstance(error, int): code = error elif isinstance(error, dict): if isinstance(error.get('message'), str): message = error['message'] if isinstance(error.get('code'), int): code = error['code'] assert request_id == 1 assert_is_error_response(item, message, code) payload = {"error": 2, "id": 1} with pytest.raises(ProtocolError) as e: protocol.message_to_item(json.dumps(payload).encode()) assert e.value.code == JSONRPC.INVALID_REQUEST assert '"result" and' in e.value.message assert not e.value.error_message payload = {"result": 4, "error": 2, "id": 1} with pytest.raises(ProtocolError) as e: protocol.message_to_item(json.dumps(payload).encode()) assert e.value.code == JSONRPC.INVALID_REQUEST assert '"result" and' in e.value.message assert not e.value.error_message def test_response_error_good(protocol_no_auto): protocol = protocol_no_auto payload = {"error": {"code": 5, "message": "bar"}, "id": 1} item, request_id = payload_to_item(protocol, payload) assert request_id == 1 assert_is_error_response(item, 'bar', 5) payload = {"error": {"code": 3, "message": "try again"}, "id": "a", "jnk": 0} item, request_id = payload_to_item(protocol, payload) assert request_id == "a" assert_is_error_response(item, 'again', 3) # BATCHES def test_batch_not_allowed(protocol): if not protocol.allow_batches: with pytest.raises(ProtocolError) as e: protocol.message_to_item(b'[]') assert e.value.code == JSONRPC.INVALID_REQUEST assert 'dictionary' in e.value.message assert b'"id":null' in e.value.error_message batch = Batch([Request('', [])]) with pytest.raises(ProtocolError) as e: protocol.batch_message(batch, {1}) 
assert e.value.code == JSONRPC.INVALID_REQUEST assert 'permit batch' in e.value.message assert not e.value.error_message def test_empty_batch(): with pytest.raises(ProtocolError) as e: Batch([]) assert e.value.code == JSONRPC.INVALID_REQUEST assert 'empty' in e.value.message assert not e.value.error_message # Message construction def test_batch_message_from_parts(protocol): with pytest.raises(ProtocolError) as e: protocol.batch_message_from_parts([]) assert 'empty' in e.value.message assert protocol.batch_message_from_parts([b'1']) == b'[1]' assert protocol.batch_message_from_parts([b'1', b'2']) == b'[1, 2]' # An empty part is not valid, but anyway. assert (protocol.batch_message_from_parts([b'1', b'', b'[3]']) == b'[1, , [3]]') def test_encode_payload(protocol): assert protocol.encode_payload(2) == b'2' assert protocol.encode_payload([2, 3]) == b'[2,3]' assert protocol.encode_payload({"a": 1}) == b'{"a":1}' assert protocol.encode_payload(True) == b'true' assert protocol.encode_payload(False) == b'false' assert protocol.encode_payload(None) == b'null' assert protocol.encode_payload("foo") == b'"foo"' with pytest.raises(ProtocolError) as e: protocol.encode_payload(b'foo') assert e.value.code == JSONRPC.INTERNAL_ERROR assert 'JSON' in e.value.message def test_JSONRPCv2_and_JSONRPCLoose_request_messages(): requests = [ (Request('foo', []), 2, {"jsonrpc": "2.0", "method": "foo", "id": 2}), (Request('foo', ()), 2, {"jsonrpc": "2.0", "method": "foo", "id": 2}), (Request('foo', {}), 2, {"jsonrpc": "2.0", "params": {}, "method": "foo", "id": 2}), (Request('foo', (1, 2)), 2, {"jsonrpc": "2.0", "method": "foo", "params": [1, 2], "id": 2}), (Request('foo', [1, 2]), 2, {"jsonrpc": "2.0", "method": "foo", "params": [1, 2], "id": 2}), (Request('foo', {"bar": 3, "baz": "bat"}), "it", {"jsonrpc": "2.0", "method": "foo", "params": {"bar": 3, "baz": "bat"}, "id": "it"}), ] notifications = [ (Notification('foo', []), {"jsonrpc": "2.0", "method": "foo"}), ] batches = [ (Batch([ Request('foo', []), Notification('bar', [2]), Request('baz', {'a': 1}), ]), [2, 3], [ {"jsonrpc": "2.0", "method": "foo", "id": 2}, {"jsonrpc": "2.0", "method": "bar", "params": [2]}, {"jsonrpc": "2.0", "method": "baz", "params": {'a': 1}, "id": 3}, ]), ] responses = [ ('foo', "it", {"jsonrpc": "2.0", "result": "foo", "id": "it"}), (2, "it", {"jsonrpc": "2.0", "result": 2, "id": "it"}), (None, -2, {"jsonrpc": "2.0", "result": None, "id": -2}), ([1, 2], -1, {"jsonrpc": "2.0", "result": [1, 2], "id": -1}), ({"kind": 1}, 0, {"jsonrpc": "2.0", "result": {"kind": 1}, "id": 0}), (RPCError(3, "j"), 1, {"jsonrpc": "2.0", "error": {"code": 3, "message": "j"}, "id": 1}), ] for protocol in [JSONRPCv2, JSONRPCLoose]: for item, request_id, payload in requests: binary = protocol.request_message(item, request_id) test_payload = json.loads(binary.decode()) assert test_payload == payload for item, payload in notifications: binary = protocol.notification_message(item) test_payload = json.loads(binary.decode()) assert test_payload == payload for result, request_id, payload in responses: binary = protocol.response_message(result, request_id) test_payload = json.loads(binary.decode()) assert test_payload == payload for batch, request_ids, payload in batches: binary = protocol.batch_message(batch, request_ids) test_payload = json.loads(binary.decode()) assert test_payload == payload def test_JSONRPCv1_messages(): requests = [ (Request('foo', []), 2, {"method": "foo", "params": [], "id": 2}), (Request('foo', [1, 2]), "s", {"method": "foo", "params": [1, 2],
"id": "s"}), (Request('foo', [1, 2]), ["x"], {"method": "foo", "params": [1, 2], "id": ["x"]}), ] notifications = [ (Notification('foo', []), {"method": "foo", "params": [], "id": None}), ] responses = [ ('foo', "it", {"result": "foo", "error": None, "id": "it"}), (2, "it", {"result": 2, "error": None, "id": "it"}), (None, -2, {"result": None, "error": None, "id": -2}), ([1, 2], -1, {"result": [1, 2], "error": None, "id": -1}), ({"kind": 1}, [1], {"result": {"kind": 1}, "error": None, "id": [1]}), (RPCError(3, "j"), 1, {"result": None, "error": {"code": 3, "message": "j"}, "id": 1}), ] protocol = JSONRPCv1 for item, request_id, payload in requests: binary = protocol.request_message(item, request_id) test_payload = json.loads(binary.decode()) assert test_payload == payload for item, payload in notifications: binary = protocol.notification_message(item) test_payload = json.loads(binary.decode()) assert test_payload == payload for result, request_id, payload in responses: binary = protocol.response_message(result, request_id) test_payload = json.loads(binary.decode()) assert test_payload == payload with pytest.raises(TypeError): protocol.request_message(Request('foo', {}, 2)) with pytest.raises(TypeError): protocol.request_message(Request('foo', {"x": 1}, 2)) def test_protocol_detection(): bad_syntax_tests = [b'', b'\xf5', b'{"method":'] tests = [ (b'[]', JSONRPCLoose), (b'""', JSONRPCLoose), (b'{"jsonrpc": "2.0"}', JSONRPCv2), (b'{"jsonrpc": "1.0"}', JSONRPCv1), # No ID (b'{"method": "part"}', JSONRPCLoose), (b'{"error": 2}', JSONRPCLoose), (b'{"result": 3}', JSONRPCLoose), # Just ID (b'{"id": 2}', JSONRPCLoose), # Result or error alone (b'{"result": 3, "id":2}', JSONRPCLoose), (b'{"error": 3, "id":2}', JSONRPCLoose), (b'{"result": 3, "error": null, "id":2}', JSONRPCv1), # Method with or without params (b'{"method": "foo", "id": 1}', JSONRPCLoose), (b'{"method": "foo", "params": [], "id":2}', JSONRPCLoose), ] for message in bad_syntax_tests: with pytest.raises(ProtocolError): JSONRPCAutoDetect.detect_protocol(message) for message, answer in tests: result = JSONRPCAutoDetect.detect_protocol(message) assert answer == result test_by_answer = {} for message, answer in tests: test_by_answer[answer] = message # Batches. Test every combination... bm_from_parts = JSONRPC.batch_message_from_parts for length in range(1, len(test_by_answer)): for combo in combinations(test_by_answer, length): batch = bm_from_parts(test_by_answer[answer] for answer in combo) protocol = JSONRPCAutoDetect.detect_protocol(batch) if JSONRPCv2 in combo: assert protocol == JSONRPCv2 elif JSONRPCv1 in combo: assert protocol == JSONRPCv1 else: assert protocol == combo[0] # # Connection tests # @pytest.mark.asyncio async def test_send_request_and_response(protocol): '''Test sending a request gives the correct outgoing message, waits for a response, and returns it. Also raises if the response is an error. 
''' req = Request('sum', [1, 2, 3]) connection = JSONRPCConnection(protocol) waiting = Event() send_message = None async def send_mess(): nonlocal send_message send_message, future = connection.send_request(req) waiting.set() assert await future == 6 # Test receipt of an error response send_message, future = connection.send_request(req) waiting.set() try: await future except Exception as e: assert_RPCError(e, JSONRPC.METHOD_NOT_FOUND, "cannot add up") send_message, future = connection.send_request(req) waiting.set() try: await future except Exception as e: assert_ProtocolError(e, JSONRPC.INVALID_REQUEST, '"result"') async def send_response(): for n in range(3): await waiting.wait() waiting.clear() assert connection.pending_requests() == [req] payload = json.loads(send_message.decode()) if protocol == JSONRPCv2: assert payload.get("jsonrpc") == "2.0" assert payload.get("method") == "sum" assert payload.get("params") == [1, 2, 3] if n == 0: message = protocol.response_message(6, payload["id"]) elif n == 1: error = RPCError(protocol.METHOD_NOT_FOUND, "cannot add up") message = protocol.response_message(error, payload["id"]) else: message = protocol.response_message(6, payload["id"]) message = message.replace(b'result', b'res') connection.receive_message(message) async with TaskGroup() as group: await group.spawn(send_mess) await group.spawn(send_response) assert not connection.pending_requests() @pytest.mark.asyncio async def test_receive_message_unmatched_response(protocol): '''Test receiving a response with an unmatchable request raises a ProtocolError to receive_message. ''' connection = JSONRPCConnection(protocol) message = protocol.response_message(1, 12345) with pytest.raises(ProtocolError) as e: await connection.receive_message(message) assert 'response to unsent request (ID: 12345)' in e.value.message message = protocol.response_message(1, None) with pytest.raises(ProtocolError) as e: await connection.receive_message(message) assert 'response to unsent request (ID: None)' in e.value.message error = RPCError(1, 'messed up') message = protocol.response_message(error, None) with pytest.raises(ProtocolError) as e: await connection.receive_message(message) assert 'diagnostic error received' in e.value.message assert 'messed up' in e.value.message @pytest.mark.asyncio async def test_send_response_round_trip(protocol): '''Test sending a request, receiving it, replying to it, and getting the response. ''' req = Request('sum', [1, 2, 3]) connection = JSONRPCConnection(protocol) queue = Queue() async def send_request(): message, future = connection.send_request(req) await queue.put(message) assert await future == 6 async def receive_request(): # This will be the request sent message = await queue.get() assert isinstance(message, bytes) assert connection.pending_requests() == [req] # Pretend we actually received this requests = connection.receive_message(message) assert requests == [req] # Send the result message = requests[0].send_result(6) # Receive the result requests = connection.receive_message(message) assert not requests async with timeout_after(0.2): async with TaskGroup() as group: await group.spawn(receive_request) await group.spawn(send_request) assert not connection.pending_requests() @pytest.mark.asyncio async def test_send_batch_round_trip(batch_protocol): '''Test sending a batch (with both Requests and Notifications), receiving it, replying to it in a random order, and getting the response in the correct order. 
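    In JSON-RPC 2.0, batch responses are matched to their requests by "id", so
    the peer may answer in any order; notifications carry no "id" and receive
    no response entry at all.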
''' protocol = batch_protocol items = [Request('echo', [n]) for n in range(15)] answers = [n for n in range(len(items))] # Replace a couple of answers with errors and throw in some notifications for pos in range(0, len(answers), 4): answers[pos] = RPCError(pos, 'division by zero') items.insert(pos, Notification('n', [pos])) batch = Batch(items) connection = JSONRPCConnection(protocol) queue = Queue() async def send_request(): # Check the returned answers are in the correct order message, future = connection.send_batch(batch) await queue.put(message) assert await future == tuple(answers) async def receive_request(): # This will be the batch request sent message = await queue.get() assert connection.pending_requests() == [batch] # Pretend we actually received this requests = connection.receive_message(message) # Check we get the requests separately answer_iter = iter(answers) req_ans = [] for request, req in zip(requests, batch): assert request == req if isinstance(request, Request): req_ans.append((request, next(answer_iter))) # Send the responses in a random order shuffle(req_ans) for request, answer in req_ans: message = request.send_result(answer) if message: assert not connection.receive_message(message) assert not connection.pending_requests() else: assert connection.pending_requests() async with TaskGroup() as group: await group.spawn(receive_request) await group.spawn(send_request) assert not connection.pending_requests() @pytest.mark.asyncio async def test_send_notification_batch(batch_protocol): '''Test that a notification batch does not wait for a response.''' protocol = batch_protocol batch = Batch([Notification('n', [n]) for n in range(10)]) connection = JSONRPCConnection(protocol) queue = Queue() async def send_request(): message, event = connection.send_batch(batch) assert not connection.pending_requests() await queue.put(message) assert event is None async def receive_request(): # This will be the batch request sent message = await queue.get() # Pretend we actually received this requests = connection.receive_message(message) # Check we get the requests separately for req, request in zip(batch, requests): assert req == request async with timeout_after(0.2): async with TaskGroup() as group: await group.spawn(receive_request) await group.spawn(send_request) assert not connection.pending_requests() @pytest.mark.asyncio async def test_batch_fails(batch_protocol): '''Test various failure cases for batches.''' protocol = batch_protocol batch = Batch([ Request('test', [1, 2, 3]), ]) connection = JSONRPCConnection(protocol) queue = Queue() async def send_request(): message, future = connection.send_batch(batch) await queue.put(message) async with ignore_after(0.01): await future async def receive_request(): # This will be the batch request sent message = await queue.get() assert connection.pending_requests() == [batch] # Send a batch response we didn't get parts = [protocol.response_message(2, "bad_id")] fake_message = protocol.batch_message_from_parts(parts) with pytest.raises(ProtocolError) as e: connection.receive_message(fake_message) assert 'response to unsent batch' in e.value.message assert connection.pending_requests() == [batch] # Send a batch with a duplicate response data = json.loads(message.decode()) parts = [protocol.response_message(2, data[0]['id'])] * 2 fake_message = protocol.batch_message_from_parts(parts) with pytest.raises(ProtocolError) as e: await connection.receive_message(fake_message) assert 'response to unsent batch' in e.value.message async with TaskGroup() 
as group: await group.spawn(receive_request) await group.spawn(send_request) assert connection.pending_requests() == [batch] @pytest.mark.asyncio async def test_send_notification(protocol): '''Test sending a notification doesn't wait.''' req = Notification('wakey', []) connection = JSONRPCConnection(protocol) queue = Queue() async def send_request(): message = connection.send_notification(req) assert isinstance(message, bytes) await queue.put(message) assert not connection.pending_requests() async def receive_request(): # This will be the notification sent message = await queue.get() assert not connection.pending_requests() # Pretend we actually received this requests = connection.receive_message(message) assert requests == [req] async with timeout_after(0.2): async with TaskGroup() as group: await group.spawn(receive_request) await group.spawn(send_request) assert not connection.pending_requests() @pytest.mark.asyncio async def test_max_response_size(protocol): request = Request('', []) result = "a" size = len(protocol.response_message(result, 0)) queue = Queue() JSONRPCConnection._id_counter = count() async def send_request_good(request): message, future = connection.send_request(request) await queue.put(message) assert await future == result async def send_request_bad(request): message, future = connection.send_request(request) await queue.put(message) try: await future assert False except Exception as e: assert_RPCError(e, JSONRPC.INVALID_REQUEST, "response too large") async def receive_request(count): # This will be the message sent message = await queue.get() # Pretend we actually received this requests = connection.receive_message(message) for req in requests: message = req.send_result(result) # Receive the result if message: assert not connection.receive_message(message) connection = JSONRPCConnection(protocol) connection.max_response_size = size async with TaskGroup() as group: await group.spawn(receive_request(1)) await group.spawn(send_request_good(request)) connection.max_response_size = size - 1 async with TaskGroup() as group: await group.spawn(receive_request(1)) await group.spawn(send_request_bad(request)) async def send_batch(batch): message, future = connection.send_batch(batch) await queue.put(message) results = await future for n, part_result in enumerate(results): if n == 0: assert part_result == result else: assert "too large" in part_result.message if protocol.allow_batches: connection.max_response_size = size + 3 batch = Batch([request, request, request]) async with TaskGroup() as group: await group.spawn(receive_request(len(batch))) await group.spawn(send_batch(batch)) def test_misc(protocol): '''Misc tests to get full coverage.''' connection = JSONRPCConnection(protocol) with pytest.raises(ProtocolError): connection.receive_message(b'[]') with pytest.raises(AssertionError): connection.send_request(Response(2)) request = Request('a', []) assert request.send_result(2) is None def test_handler_invocation(): # Peculiar function signatures # pow - Built-in; 2 positional args, 1 optional 3rd named arg powb = pow def add_3(x, y, z=0): return x + y + z def add_many(first, second=0, *values): values += (first, second) return sum(values) def echo_2(first, *, second=2): return [first, second] def kwargs(start, *kwargs): return start + len(kwargs) def both(start=2, *args, **kwargs): return start + len(args) * 10 + len(kwargs) * 4 good_requests = ( (Request('add_3', (1, 2, 3)), 6), (Request('add_3', [5, 7]), 12), (Request('add_3', {'x': 5, 'y': 7}), 12), (Request('add_3',
{'x': 5, 'y': 7, 'z': 3}), 15), (Request('add_many', [1]), 1), (Request('add_many', [5, 50, 500]), 555), (Request('add_many', list(range(10))), 45), (Request('add_many', {'first': 1}), 1), (Request('add_many', {'first': 1, 'second': 10}), 11), (Request('powb', [2, 3]), 8), (Request('powb', [2, 3, 5]), 3), (Request('echo_2', ['ping']), ['ping', 2]), (Request('echo_2', {'first': 1, 'second': 8}), [1, 8]), (Request('kwargs', [1]), 1), (Request('kwargs', [1, 2]), 2), (Request('kwargs', {'start': 3}), 3), (Request('both', []), 2), (Request('both', [1]), 1), (Request('both', [5, 2]), 15), (Request('both', {'end': 4}), 6), (Request('both', {'start': 3}), 3), (Request('both', {'start': 3, 'end': 1, '3rd': 1}), 11), ) for request, result in good_requests: handler = locals()[request.method] invocation = handler_invocation(handler, request) assert invocation() == result if sys.version_info < (3, 8): powb_request = (Request('powb', {"x": 2, "y": 3}), 'cannot be called') else: powb_request = (Request('powb', {"x": 2, "y": 3}), 'requires parameters') bad_requests = [ (Request('missing_method', []), 'unknown method'), (Request('add_many', []), 'requires 1'), (Request('add_many', {'first': 1, 'values': []}), 'values'), powb_request, (Request('echo_2', ['ping', 'pong']), 'at most 1'), (Request('echo_2', {'first': 1, 'second': 8, '3rd': 1}), '3rd'), (Request('kwargs', []), 'requires 1'), (Request('kwargs', {'end': 4}), "start"), (Request('kwargs', {'start': 3, 'end': 1, '3rd': 1}), '3rd'), ] for request, text in bad_requests: with pytest.raises(RPCError) as e: handler = locals().get(request.method) handler_invocation(handler, request) assert text in e.value.message aiorpcX-0.24/tests/test_session.py000077500000000000000000001166601474217261100173240ustar00rootroot00000000000000import asyncio import json import logging import sys import time from contextlib import suppress import pytest from aiorpcx import ( MessageSession, ProtocolError, sleep, spawn, TaskGroup, BitcoinFramer, SOCKS5, SOCKSProxy, connect_rs, JSONRPC, Batch, RPCError, TaskTimeout, RPCSession, timeout_after, serve_rs, NewlineFramer, BatchError, ExcessiveSessionCostError, SessionKind, ReplyAndDisconnect, ignore_after, handler_invocation, CancelledError, ) from aiorpcx.session import Concurrency from util import RaiseTest if sys.version_info >= (3, 7): from asyncio import all_tasks else: from asyncio import Task all_tasks = Task.all_tasks def raises_method_not_found(message): return RaiseTest(JSONRPC.METHOD_NOT_FOUND, message, RPCError) class MyServerSession(RPCSession): sessions = [] def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) self.notifications = [] self.sessions.append(self) assert self.session_kind == SessionKind.SERVER @classmethod async def current_server(self): await sleep(0.05) return self.sessions[0] async def connection_lost(self): await super().connection_lost() self.sessions.remove(self) async def handle_request(self, request): handler = getattr(self, f'on_{request.method}', None) invocation = handler_invocation(handler, request) return await invocation() async def on_send_bad_response(self, response): message = json.dumps(response).encode() await self._send_message(message) async def on_echo(self, value): return value async def on_notify(self, thing): self.notifications.append(thing) async def on_bug(self): raise ValueError async def on_costly_error(self, cost): raise RPCError(1, "that cost a bunch!", cost=cost) async def on_disconnect(self, result=RPCError): if result is RPCError: raise 
ReplyAndDisconnect(RPCError(1, 'incompatible version')) raise ReplyAndDisconnect(result) async def on_sleepy(self): await sleep(10) def in_caplog(caplog, message): return any(message in record.message for record in caplog.records) def caplog_count(caplog, message): return sum(message in record.message for record in caplog.records) @pytest.fixture def server_port(unused_tcp_port, event_loop): coro = serve_rs(MyServerSession, 'localhost', unused_tcp_port, loop=event_loop) server = event_loop.run_until_complete(coro) yield unused_tcp_port if hasattr(asyncio, 'all_tasks'): tasks = asyncio.all_tasks(event_loop) else: tasks = asyncio.Task.all_tasks(loop=event_loop) async def close_all(): server.close() await server.wait_closed() if tasks: await asyncio.wait(tasks) event_loop.run_until_complete(close_all()) class TestRPCSession: @pytest.mark.asyncio async def test_no_proxy(self, server_port): proxy = SOCKSProxy('localhost:79', SOCKS5, None) with pytest.raises(OSError): async with connect_rs('localhost', server_port, proxy=proxy): pass @pytest.mark.asyncio async def test_handlers(self, server_port): async with timeout_after(0.1): async with connect_rs('localhost', server_port) as session: assert session.session_kind == SessionKind.CLIENT assert session.proxy() is None with raises_method_not_found('something'): await session.send_request('something') await session.send_notification('something') assert session.is_closing() @pytest.mark.asyncio async def test_send_request(self, server_port): async with connect_rs('localhost', server_port) as session: assert await session.send_request('echo', [23]) == 23 assert session.transport._closed_event.is_set() assert session.transport._process_messages_task.done() @pytest.mark.asyncio async def test_send_request_buggy_handler(self, server_port): async with connect_rs('localhost', server_port) as session: with RaiseTest(JSONRPC.INTERNAL_ERROR, 'internal server error', RPCError): await session.send_request('bug') @pytest.mark.asyncio async def test_unexpected_response(self, server_port, caplog): async with connect_rs('localhost', server_port) as session: # A request not a notification so we don't exit immediately response = {"jsonrpc": "2.0", "result": 2, "id": -1} with caplog.at_level(logging.DEBUG): await session.send_request('send_bad_response', (response, )) assert in_caplog(caplog, 'unsent request') @pytest.mark.asyncio async def test_unanswered_request_count(self, server_port): async with connect_rs('localhost', server_port) as session: server_session = await MyServerSession.current_server() assert session.unanswered_request_count() == 0 assert server_session.unanswered_request_count() == 0 async with ignore_after(0.01): await session.send_request('sleepy') assert session.unanswered_request_count() == 0 assert server_session.unanswered_request_count() == 1 @pytest.mark.asyncio async def test_send_request_bad_args(self, server_port): async with connect_rs('localhost', server_port) as session: # ProtocolError as it's a protocol violation with RaiseTest(JSONRPC.INVALID_ARGS, 'list', ProtocolError): await session.send_request('echo', "23") @pytest.mark.asyncio async def test_send_request_timeout0(self, server_port): async with connect_rs('localhost', server_port) as session: with pytest.raises(TaskTimeout): async with timeout_after(0): await session.send_request('echo', [23]) @pytest.mark.asyncio async def test_send_request_timeout(self, server_port): async with connect_rs('localhost', server_port) as session: server_session = await 
MyServerSession.current_server() with pytest.raises(TaskTimeout): async with timeout_after(0.01): await session.send_request('sleepy') # Assert the server doesn't treat cancellation as an error assert server_session.errors == 0 @pytest.mark.asyncio async def test_error_base_cost(self, server_port): async with connect_rs('localhost', server_port) as session: server_session = await MyServerSession.current_server() server_session.error_base_cost = server_session.cost_hard_limit * 1.1 await session._send_message(b'') await sleep(0.05) assert server_session.errors == 1 assert server_session.cost > server_session.cost_hard_limit # Check next request raises and cuts us off with pytest.raises(RPCError): await session.send_request('echo', [23]) await sleep(0.02) assert session.is_closing() @pytest.mark.asyncio async def test_RPCError_cost(self, server_port): async with connect_rs('localhost', server_port) as session: server_session = await MyServerSession.current_server() err = RPCError(0, 'message') assert err.cost == 0 with pytest.raises(RPCError): await session.send_request('costly_error', [1000]) # It can trigger a cost recalc which refunds a tad epsilon = 1 assert server_session.cost > server_session.error_base_cost + 1000 - epsilon @pytest.mark.asyncio async def test_send_notification(self, server_port): async with connect_rs('localhost', server_port) as session: server = await MyServerSession.current_server() await session.send_notification('notify', ['test']) await sleep(0.001) assert server.notifications == ['test'] @pytest.mark.asyncio async def test_force_close(self, server_port): async with connect_rs('localhost', server_port) as session: assert not session.transport._closed_event.is_set() await session.close(force_after=0.001) assert session.transport._closed_event.is_set() @pytest.mark.asyncio async def test_force_close_abort_codepath(self, server_port): async with connect_rs('localhost', server_port) as session: protocol = session.transport assert not protocol._closed_event.is_set() await session.close(force_after=0) assert protocol._closed_event.is_set() @pytest.mark.asyncio async def test_verbose_logging(self, server_port, caplog): async with connect_rs('localhost', server_port) as session: session.verbosity = 4 with caplog.at_level(logging.DEBUG): await session.send_request('echo', ['wait']) assert in_caplog(caplog, "sending message b") assert in_caplog(caplog, "received data b") @pytest.mark.asyncio async def test_framer_MemoryError(self, server_port, caplog): async with connect_rs('localhost', server_port, framer=NewlineFramer(5)) as session: msg = 'w' * 50 raw_msg = msg.encode() # Even though long it will be sent in one bit request = session.send_request('echo', [msg]) assert await request == msg assert not caplog.records session.transport.data_received(raw_msg) # Unframed; no \n await sleep(0) assert len(caplog.records) == 1 assert in_caplog(caplog, 'dropping message over 5 bytes') # @pytest.mark.asyncio # async def test_resource_release(self, server_port): # loop = asyncio.get_event_loop() # tasks = all_tasks(loop) # try: # session = connect_rs('localhost', 0) # await session.create_connection() # except OSError: # pass # assert all_tasks(loop) == tasks # async with connect_rs('localhost', server_port): # pass # await asyncio.sleep(0.01) # Let things be processed # assert all_tasks(loop) == tasks @pytest.mark.asyncio async def test_pausing(self, server_port): called = [] limit = None def my_write(data): called.append(data) if len(called) == limit: 
session.transport.pause_writing() async with connect_rs('localhost', server_port) as session: protocol = session.transport assert protocol._can_send.is_set() asyncio_transport = protocol._asyncio_transport try: asyncio_transport.write = my_write except AttributeError: # uvloop: transport.write is read-only return await session._send_message(b'a') assert protocol._can_send.is_set() assert called called.clear() async def monitor(): await sleep(0.002) assert called == [b'A\n', b'very\n'] assert not protocol._can_send.is_set() protocol.resume_writing() assert protocol._can_send.is_set() limit = 2 msgs = b'A very long and boring meessage'.split() task = await spawn(monitor) for msg in msgs: await session._send_message(msg) assert called == [session.transport._framer.frame(msg) for msg in msgs] limit = None # Check idempotent protocol.resume_writing() assert task.result() is None @pytest.mark.asyncio async def test_slow_connection_aborted(self, server_port): async with connect_rs('localhost', server_port) as session: protocol = session.transport assert session.max_send_delay >= 10 session.max_send_delay = 0.004 protocol.pause_writing() assert not protocol._can_send.is_set() task = await spawn(session._send_message(b'a')) await sleep(0.1) assert isinstance(task.exception(), TaskTimeout) assert protocol._can_send.is_set() assert session.is_closing() @pytest.mark.asyncio async def test_concurrency(self, server_port): async with connect_rs('localhost', server_port) as session: # By default clients don't have a hard limit assert session.cost_hard_limit == 0 session.cost_hard_limit = session.cost_soft_limit * 2 # Prevent this interfering session.cost_decay_per_sec = 0 # Test usage below soft limit session.cost = session.cost_soft_limit - 10 session.recalc_concurrency() assert session._incoming_concurrency.max_concurrent == session.initial_concurrent assert session._cost_fraction == 0.0 # Test usage at soft limit doesn't affect concurrency session.cost = session.cost_soft_limit session.recalc_concurrency() assert session._incoming_concurrency.max_concurrent == session.initial_concurrent assert session._cost_fraction == 0.0 # Test usage half-way session.cost = (session.cost_soft_limit + session.cost_hard_limit) // 2 session.recalc_concurrency() assert 1 < session._incoming_concurrency.max_concurrent < session.initial_concurrent assert 0.49 < session._cost_fraction < 0.51 # Test at hard limit session.cost = session.cost_hard_limit session.recalc_concurrency() assert session._cost_fraction == 1.0 # Test above hard limit disconnects session.cost = session.cost_hard_limit + 1 session.recalc_concurrency() with pytest.raises(ExcessiveSessionCostError): async with session._incoming_concurrency: pass @pytest.mark.asyncio async def test_concurrency_no_limit_for_outgoing(self, server_port): async with connect_rs('localhost', server_port) as session: # Prevent this interfering session.cost_decay_per_sec = 0 # Test usage half-way session.cost = (RPCSession.cost_soft_limit + RPCSession.cost_hard_limit) // 2 session.recalc_concurrency() assert session._incoming_concurrency.max_concurrent == session.initial_concurrent assert session._cost_fraction == 0 # Test above hard limit does not disconnect session.cost = RPCSession.cost_hard_limit + 1 session.recalc_concurrency() async with session._incoming_concurrency: pass @pytest.mark.asyncio async def test_concurrency_decay(self, server_port): async with connect_rs('localhost', server_port) as session: session.cost_decay_per_sec = 100 session.cost = 1000 await sleep(0.1) 
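            # Expected arithmetic, assuming recalc_concurrency() applies a linear
            # decay of cost_decay_per_sec per elapsed second:
            #     cost ~= 1000 - 100 * 0.1 = 990
            # The wide 970 < cost < 992 window below tolerates event-loop
            # scheduling jitter around the nominal 0.1s sleep.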
session.recalc_concurrency() assert 970 < session.cost < 992 @pytest.mark.asyncio async def test_concurrency_hard_limit_0(self, server_port): async with connect_rs('localhost', server_port) as session: session.cost = 1_000_000_000 session.cost_hard_limit = 0 session.recalc_concurrency() assert session._incoming_concurrency.max_concurrent == session.initial_concurrent @pytest.mark.asyncio async def test_extra_cost(self, server_port): async with connect_rs('localhost', server_port) as session: # By default clients don't have a hard limit assert session.cost_hard_limit == 0 session.cost_hard_limit = session.cost_soft_limit * 2 session.extra_cost = lambda: session.cost_soft_limit + 1 session.recalc_concurrency() assert 1 > session._cost_fraction > 0 session.extra_cost = lambda: session.cost_hard_limit + 1 session.recalc_concurrency() assert session._cost_fraction > 1 @pytest.mark.asyncio async def test_request_over_hard_limit(self, server_port): async with connect_rs('localhost', server_port) as session: server = await MyServerSession.current_server() server.bump_cost(server.cost_hard_limit + 100) async with timeout_after(0.1): with pytest.raises(RPCError) as e: await session.send_request('echo', [23]) assert 'excessive resource usage' in str(e.value) @pytest.mark.asyncio async def test_request_sleep(self, server_port): async with connect_rs('localhost', server_port) as session: server = await MyServerSession.current_server() server.bump_cost((server.cost_soft_limit + server.cost_hard_limit) / 2) server.cost_sleep = 0.1 t1 = time.time() await session.send_request('echo', [23]) t2 = time.time() assert t2 - t1 > (server.cost_sleep / 2) * 0.9 # Fudge factor for Linux @pytest.mark.asyncio async def test_server_busy(self, server_port): async with connect_rs('localhost', server_port) as session: server = await MyServerSession.current_server() server.processing_timeout = 0.01 with pytest.raises(RPCError) as e: await session.send_request('sleepy') assert 'server busy' in str(e.value) assert server.errors == 1 @pytest.mark.asyncio async def test_reply_and_disconnect_value(self, server_port): async with connect_rs('localhost', server_port) as session: value = 42 assert await session.send_request('disconnect', [value]) == value await sleep(0.01) assert session.is_closing() @pytest.mark.asyncio async def test_reply_and_disconnect_error(self, server_port): async with connect_rs('localhost', server_port) as session: with pytest.raises(RPCError) as e: assert await session.send_request('disconnect') await sleep(0.01) exc = e.value assert exc.code == 1 and exc.message == 'incompatible version' assert session.is_closing() @pytest.mark.asyncio async def test_send_empty_batch(self, server_port): async with connect_rs('localhost', server_port) as session: with RaiseTest(JSONRPC.INVALID_REQUEST, 'empty', ProtocolError): async with session.send_batch() as batch: pass assert len(batch) == 0 assert batch.batch is None assert batch.results is None @pytest.mark.asyncio async def test_send_batch(self, server_port): async with connect_rs('localhost', server_port) as session: async with session.send_batch() as batch: batch.add_request("echo", [1]) batch.add_notification("echo", [2]) batch.add_request("echo", [3]) assert isinstance(batch.batch, Batch) assert len(batch) == 3 assert isinstance(batch.results, tuple) assert len(batch.results) == 2 assert batch.results == (1, 3) @pytest.mark.asyncio async def test_send_batch_errors_quiet(self, server_port): async with connect_rs('localhost', server_port) as session: async with 
session.send_batch() as batch: batch.add_request("echo", [1]) batch.add_request("bug") assert isinstance(batch.batch, Batch) assert len(batch) == 2 assert isinstance(batch.results, tuple) assert len(batch.results) == 2 assert isinstance(batch.results[1], RPCError) @pytest.mark.asyncio async def test_send_batch_errors(self, server_port): async with connect_rs('localhost', server_port) as session: with pytest.raises(BatchError) as e: async with session.send_batch(raise_errors=True) as batch: batch.add_request("echo", [1]) batch.add_request("bug") assert e.value.request is batch assert isinstance(batch.batch, Batch) assert len(batch) == 2 assert isinstance(batch.results, tuple) assert len(batch.results) == 2 assert isinstance(batch.results[1], RPCError) @pytest.mark.asyncio async def test_send_batch_cancelled(self, server_port): async with connect_rs('localhost', server_port) as session: async def send_batch(): async with session.send_batch(raise_errors=True) as batch: batch.add_request('sleepy') task = await spawn(send_batch) await session.close() await asyncio.wait([task]) assert task.cancelled() @pytest.mark.asyncio async def test_send_batch_bad_request(self, server_port): async with connect_rs('localhost', server_port) as session: with RaiseTest(JSONRPC.METHOD_NOT_FOUND, 'string', ProtocolError): async with session.send_batch() as batch: batch.add_request(23) @pytest.mark.asyncio async def test_send_request_throttling(self, server_port): async with connect_rs('localhost', server_port) as session: N = 3 session.recalibrate_count = N prior = session._outgoing_concurrency.max_concurrent async with TaskGroup() as group: for n in range(N): await group.spawn(session.send_request("echo", ["ping"])) current = session._outgoing_concurrency.max_concurrent assert prior * 1.2 > current > prior @pytest.mark.asyncio async def test_send_batch_throttling(self, server_port): async with connect_rs('localhost', server_port) as session: N = 3 session.recalibrate_count = N prior = session._outgoing_concurrency.max_concurrent async with session.send_batch() as batch: for n in range(N): batch.add_request("echo", ["ping"]) current = session._outgoing_concurrency.max_concurrent assert prior * 1.2 > current > prior @pytest.mark.asyncio async def test_sent_request_timeout(self, server_port): async with connect_rs('localhost', server_port) as session: session.sent_request_timeout = 0.01 start = time.time() with pytest.raises(TaskTimeout): await session.send_request('sleepy') assert time.time() - start < 0.1 @pytest.mark.asyncio async def test_log_me(self, server_port, caplog): async with connect_rs('localhost', server_port) as session: server = await MyServerSession.current_server() with caplog.at_level(logging.INFO): assert server.log_me is False await session.send_request('echo', ['ping']) assert caplog_count(caplog, '"method":"echo"') == 0 server.log_me = True await session.send_request('echo', ['ping']) assert caplog_count(caplog, '"method":"echo"') == 1 class WireRPCSession(RPCSession): # For tests of wire messages async def send(self, item): if not isinstance(item, str): item = json.dumps(item) item = item.encode() await self._send_message(item) async def response(self): message = await self.transport.receive_message() return json.loads(message.decode()) def connect_wire_session(host, port): return connect_rs(host, port, session_factory=WireRPCSession) class TestWireResponses(object): # These tests are similar to those in the JSON RPC v2 specification @pytest.mark.asyncio async def test_send_request(self, 
server_port): async with connect_wire_session('localhost', server_port) as session: item = {"jsonrpc": "2.0", "method": "echo", "params": [[42, 43]], "id": 1} await session.send(item) assert await session.response() == {"jsonrpc": "2.0", "result": [42, 43], "id": 1} @pytest.mark.asyncio async def test_send_request_named(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = {"jsonrpc": "2.0", "method": "echo", "params": {"value": [42, 43]}, "id": 3} await session.send(item) assert await session.response() == {"jsonrpc": "2.0", "result": [42, 43], "id": 3} @pytest.mark.asyncio async def test_send_notification(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = {"jsonrpc": "2.0", "method": "echo", "params": [[42, 43]]} await session.send(item) with pytest.raises(TaskTimeout): async with timeout_after(0.002): await session.response() @pytest.mark.asyncio async def test_send_non_existent_notification(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = {"jsonrpc": "2.0", "method": "zz", "params": [[42, 43]]} await session.send(item) with pytest.raises(TaskTimeout): async with timeout_after(0.002): await session.response() @pytest.mark.asyncio async def test_send_non_existent_method(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = {"jsonrpc": "2.0", "method": "foobar", "id": 0} await session.send(item) assert await session.response() == { "jsonrpc": "2.0", "id": 0, "error": {'code': -32601, 'message': 'unknown method "foobar"'}} @pytest.mark.asyncio async def test_send_invalid_json(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = '{"jsonrpc": "2.0", "method": "foobar, "params": "bar", "b]' await session.send(item) assert await session.response() == { "jsonrpc": "2.0", "error": {"code": -32700, "message": "invalid JSON"}, "id": None} @pytest.mark.asyncio async def test_send_invalid_request_object(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = {"jsonrpc": "2.0", "method": 1, "params": "bar"} await session.send(item) assert await session.response() == { "jsonrpc": "2.0", "id": None, "error": {"code": -32602, "message": "invalid request arguments: bar"}} @pytest.mark.asyncio async def test_send_batch_invalid_json(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = ('[{"jsonrpc": "2.0", "method": "sum", "params": [1,2,4],' '"id": "1"}, {"jsonrpc": "2.0", "method" ]') await session.send(item) assert await session.response() == { "jsonrpc": "2.0", "id": None, "error": {"code": -32700, "message": "invalid JSON"}} @pytest.mark.asyncio async def test_send_empty_batch(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = [] await session.send(item) assert await session.response() == { "jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "batch is empty"}} @pytest.mark.asyncio async def test_send_invalid_batch(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = [1] await session.send(item) assert await session.response() == [{ "jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "request object must be a dictionary"}}] @pytest.mark.asyncio async def test_send_invalid_batch_3(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = [1, 2, 
3] await session.send(item) assert await session.response() == [{ "jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "request object must be a dictionary"}}] * 3 @pytest.mark.asyncio async def test_send_partly_invalid_batch(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = [1, {"jsonrpc": "2.0", "method": "echo", "params": [42], "id": 0}] await session.send(item) assert await session.response() == [ {"jsonrpc": "2.0", "id": None, "error": {"code": -32600, "message": "request object must be a dictionary"}}, {"jsonrpc": "2.0", "result": 42, "id": 0}] @pytest.mark.asyncio async def test_send_mixed_batch(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = [ {"jsonrpc": "2.0", "method": "echo", "params": [40], "id": 3}, {"jsonrpc": "2.0", "method": "echo", "params": [42]}, {"jsonrpc": "2.0", "method": "echo", "params": [41], "id": 2} ] await session.send(item) assert await session.response() == [ {"jsonrpc": "2.0", "result": 40, "id": 3}, {"jsonrpc": "2.0", "result": 41, "id": 2} ] @pytest.mark.asyncio async def test_send_notification_batch(self, server_port): async with connect_wire_session('localhost', server_port) as session: item = [{"jsonrpc": "2.0", "method": "echo", "params": [42]}] * 2 await session.send(item) with pytest.raises(TaskTimeout): async with timeout_after(0.002): assert await session.response() class MessageServer(MessageSession): def __init__(self, *args, **kwargs): super().__init__(*args, **kwargs) MessageServer._current_session = self self.messages = [] @classmethod async def current_server(self): while True: await sleep(0.001) if self._current_session: return self._current_session async def connection_lost(self): await super().connection_lost() MessageServer._current_session = None async def handle_message(self, message): command, payload = message self.messages.append(message) if command == b'syntax': raise SyntaxError elif command == b'protocol': raise ProtocolError(2, 'Not allowed') elif command == b'cancel': raise CancelledError elif command == b'sleep': await sleep(0.2) @pytest.fixture def msg_server_port(event_loop, unused_tcp_port): coro = serve_rs(MessageServer, 'localhost', unused_tcp_port, loop=event_loop) server = event_loop.run_until_complete(coro) yield unused_tcp_port if hasattr(asyncio, 'all_tasks'): tasks = asyncio.all_tasks(event_loop) else: tasks = asyncio.Task.all_tasks(loop=event_loop) async def close_all(): server.close() await server.wait_closed() if tasks: await asyncio.wait(tasks) event_loop.run_until_complete(close_all()) def connect_message_session(host, port, proxy=None, framer=None): return connect_rs(host, port, proxy=proxy, framer=framer, session_factory=MessageSession) class TestMessageSession(object): @pytest.mark.asyncio async def test_basic_send(self, msg_server_port): async with connect_message_session('localhost', msg_server_port) as session: server_session = await MessageServer.current_server() await session.send_message((b'version', b'abc')) await sleep(0.02) assert server_session.messages == [(b'version', b'abc')] @pytest.mark.asyncio async def test_many_sends(self, msg_server_port): count = 12 async with connect_message_session('localhost', msg_server_port) as session: server_session = await MessageServer.current_server() for n in range(count): await session.send_message((b'version', b'abc')) assert server_session.messages == [(b'version', b'abc')] * count @pytest.mark.asyncio async def test_errors(self, msg_server_port, 
caplog): async with connect_message_session('localhost', msg_server_port) as session: await session.send_message((b'syntax', b'')) await session.send_message((b'protocol', b'')) await session.send_message((b'cancel', b'')) assert in_caplog(caplog, 'exception handling') assert in_caplog(caplog, 'Not allowed') @pytest.mark.asyncio async def test_bad_magic(self, msg_server_port, caplog): framer = BitcoinFramer(magic=bytes(4)) async with connect_message_session('localhost', msg_server_port, framer=framer) as session: await session.send_message((b'version', b'')) await sleep(0.01) assert in_caplog(caplog, 'bad network magic') @pytest.mark.asyncio async def test_bad_checksum(self, msg_server_port, caplog): framer = BitcoinFramer() framer._checksum = lambda payload: bytes(32) async with connect_message_session('localhost', msg_server_port, framer=framer) as session: await session.send_message((b'version', b'')) assert in_caplog(caplog, 'checksum mismatch') @pytest.mark.asyncio async def test_oversized_message(self, msg_server_port, caplog): big = BitcoinFramer.max_payload_size async with connect_message_session('localhost', msg_server_port) as session: await session.send_message((b'version', bytes(big))) assert not in_caplog(caplog, 'oversized payload') async with connect_message_session('localhost', msg_server_port) as session: await session.send_message((b'version', bytes(big + 1))) assert in_caplog(caplog, 'oversized payload') @pytest.mark.asyncio async def test_proxy(self, msg_server_port): proxy = SOCKSProxy('localhost:79', SOCKS5, None) with pytest.raises(OSError): async with connect_message_session('localhost', msg_server_port, proxy=proxy): pass @pytest.mark.asyncio async def test_request_sleeps(self, msg_server_port, caplog): async with connect_message_session('localhost', msg_server_port) as session: server = await MessageServer.current_server() server.bump_cost((server.cost_soft_limit + server.cost_hard_limit) / 2) # Messaging doesn't wait, so this is just for code coverage await session.send_message((b'version', b'abc')) @pytest.mark.asyncio async def test_request_over_hard_limit(self, msg_server_port): async with connect_message_session('localhost', msg_server_port) as session: server = await MessageServer.current_server() server.bump_cost(server.cost_hard_limit + 100) await session.send_message((b'version', b'abc')) await sleep(0.05) assert session.is_closing() @pytest.mark.asyncio async def test_server_busy(self, msg_server_port, caplog): async with connect_message_session('localhost', msg_server_port) as session: server = await MessageServer.current_server() server.processing_timeout = 0.01 with caplog.at_level(logging.INFO): await session.send_message((b'sleep', b'')) await sleep(0.05) assert server.errors == 1 assert in_caplog(caplog, 'timed out') class TestConcurrency: def test_concurrency_constructor(self): Concurrency(3) Concurrency(target=6) Concurrency(target=0) with pytest.raises(ValueError): Concurrency(target=-1) @pytest.mark.asyncio async def test_concurrency_control(self): in_flight = 0 c = Concurrency(target=3) pause = 0.01 counter = 0 async def make_workers(): async def worker(): nonlocal in_flight, counter async with c: counter += 1 in_flight += 1 await sleep(pause) in_flight -= 1 async with TaskGroup() as group: for n in range(100): await group.spawn(worker) async def get_stable_in_flight(): nonlocal in_flight prior = in_flight while True: await sleep(0) if in_flight == prior: return in_flight prior = in_flight task = await spawn(make_workers) try: await sleep(0) 
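            # The assertions below exercise Concurrency.set_target(): after each
            # retarget, get_stable_in_flight() polls until the number of workers
            # inside `async with c` stops changing, and that steady-state count
            # should match the most recent target.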
assert await get_stable_in_flight() == 3 c.set_target(3) await sleep(pause * 1.1) assert await get_stable_in_flight() == 3 c.set_target(1) await sleep(pause * 1.1) assert await get_stable_in_flight() == 1 c.set_target(10) await sleep(pause * 1.1) assert await get_stable_in_flight() == 10 c.set_target(1) await sleep(pause * 1.1) assert await get_stable_in_flight() == 1 c.set_target(5) await sleep(pause * 1.1) assert await get_stable_in_flight() == 5 c.set_target(0) await sleep(pause * 1.1) assert await get_stable_in_flight() == 0 # We deliberately don't recover from 0. To do so set_target needs to release # once if existing value is zero. c.set_target(3) await sleep(pause * 1.1) assert await get_stable_in_flight() == 0 finally: task.cancel() with suppress(CancelledError, ExcessiveSessionCostError): await task @pytest.mark.asyncio async def test_retarget_accounting(self): c = Concurrency(target=2) async def worker(n): async with c: await sleep(0.001 * n) if n == 1: c.set_target(1) for n in range(1, 4): await spawn(worker, n) # Whilst this task sleeps, the sequence is: # Worker 1 grabs C and sleeps # Worker 2 grabs C and sleeps # Worker 3 cannot grab C, so sleeps # Worker 1 wakes up, sets target to 1, ends releasing C # Worker 3 wakes up, grabs C, and retargets before entering the context block. # It sleeps trying to acquire the semaphore. # Worker 2 wakes up, ends releasing C # Worker 3 wakes up, enters the context block, sleeps, and ends releasing C. # # This is a test that worker 3, when retargetting, decrements C._sem_value before # sleeping. If it doesn't, worker 2, on exiting its context block, thinks nothing else # is trying to retarget the semaphore, and so reduces C._sem_value instead # of releasing the semaphore. This means that worker 3 never wakes up.
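        # A hedged sketch (assumed shape, not necessarily aiorpcX's actual
        # internals) of the accounting the scenario above depends on -- a
        # retargeting waiter must record its intent before blocking:
        #
        #     self._sem_value -= 1              # note the pending retarget first
        #     await self._semaphore.acquire()   # may block here
        #
        # so that a worker releasing the context sees the pending retarget and
        # calls release() on the semaphore instead of merely decrementing the
        # counter.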
await sleep(0.05) assert not c._semaphore.locked() aiorpcX-0.24/tests/test_socks.py000077500000000000000000000571341474217261100167630ustar00rootroot00000000000000import asyncio import ipaddress import os import struct from functools import partial from random import randrange import pytest from aiorpcx.socks import NeedData from aiorpcx.rawsocket import RSTransport from aiorpcx import ( RPCSession, NetAddress, SOCKS5, SOCKSProxy, SOCKS4a, SOCKS4, connect_rs, SOCKSProtocolError, SOCKSUserAuth, SOCKSFailure, SOCKSError, SOCKSRandomAuth, ) # TODO : Server tests - short and close, or just waiting no response GCOM = NetAddress('www.google.com', 80) IPv6 = NetAddress('::', 80) GDNS = NetAddress('8.8.8.8', 53) SOCKS4_addresses = (GDNS, ) SOCKS4a_addresses = (GDNS, GCOM) SOCKS5_addresses = (GDNS, GCOM, IPv6) auth_methods = [None, SOCKSUserAuth('user', 'pass')] @pytest.fixture(params=SOCKS4_addresses) def addr4(request): return request.param @pytest.fixture(params=set(SOCKS5_addresses) - set(SOCKS4_addresses)) def addr4_bad(request): return request.param @pytest.fixture(params=SOCKS4a_addresses) def addr4a(request): return request.param @pytest.fixture(params=[IPv6]) def addr4a_bad(request): return request.param @pytest.fixture(params=SOCKS5_addresses) def addr5(request): return request.param @pytest.fixture(params=auth_methods) def auth(request): return request.param @pytest.fixture(params=[0, 2]) def chosen_auth(request): return request.param class HangingError(Exception): pass class FakeResponder(object): def __init__(self, response): self.response = response self.messages = [] def send(self, message): self.messages.append(message) def read(self, count): assert count > 0 if count > len(self.response): raise HangingError count = randrange(0, count + 1) response = self.response[:count] self.response = self.response[count:] return response def run_communication(client, server): while True: try: message = client.next_message() except NeedData as e: data = server.read(e.args[0]) client.receive_data(data) continue if message is None: return server.messages server.send(message) class TestSOCKS4(object): @classmethod def response(cls): return bytes([0, 0x5a]) + os.urandom(6) def short_bytes(self): result = self.response() return result[:randrange(0, len(result) - 1)] def fail_bytes(self, code): return bytes([0, code]) + os.urandom(6) def bad_first_byte(self): return bytes([randrange(1, 256)]) + self.response()[1:] def test_good_response(self, addr4, auth): client = SOCKS4(addr4, auth) server = FakeResponder(self.response()) messages = run_communication(client, server) user_id = b'' if not auth else auth.username.encode() packed = addr4.host.packed data = b''.join((b'\4\1', struct.pack('>H', addr4.port), packed, user_id, b'\0')) assert messages == [data] def test_short_response(self, addr4, auth): client = SOCKS4(addr4, auth) server = FakeResponder(self.short_bytes()) with pytest.raises(HangingError): run_communication(client, server) def test_request_rejected_89(self, addr4, auth): client = SOCKS4(addr4, auth) server = FakeResponder(self.fail_bytes(89)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'unknown SOCKS4 reply code 89' in str(err.value) def test_request_rejected_91(self, addr4, auth): client = SOCKS4(addr4, auth) server = FakeResponder(self.fail_bytes(91)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'request rejected or failed' in str(err.value) def test_request_rejected_92(self, addr4, auth): client = SOCKS4(addr4, 
auth) server = FakeResponder(self.fail_bytes(92)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'cannot connect to identd' in str(err.value) def test_request_rejected_93(self, addr4, auth): client = SOCKS4(addr4, auth) server = FakeResponder(self.fail_bytes(93)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'report different' in str(err.value) def test_response_bad_first_byte(self, addr4, auth): client = SOCKS4(addr4, auth) server = FakeResponder(self.bad_first_byte()) with pytest.raises(SOCKSProtocolError) as err: run_communication(client, server) assert 'invalid SOCKS4 proxy response' in str(err.value) def test_rejects_others(self, addr4_bad, auth): with pytest.raises(SOCKSProtocolError) as err: SOCKS4(addr4_bad, auth) assert 'SOCKS4 requires an IPv4' in str(err.value) class TestSOCKS4a(object): @classmethod def response(cls): return bytes([0, 0x5a]) + os.urandom(6) def short_bytes(self): result = self.response() return result[:randrange(0, len(result) - 1)] def fail_bytes(self, code): return bytes([0, code]) + os.urandom(6) def bad_first_byte(self): return bytes([randrange(1, 256)]) + self.response()[1:] def test_good_response(self, addr4a, auth): client = SOCKS4a(addr4a, auth) server = FakeResponder(self.response()) messages = run_communication(client, server) user_id = b'' if not auth else auth.username.encode() if isinstance(addr4a.host, str): host_bytes = addr4a.host.encode() + b'\0' ip_packed = b'\0\0\0\1' else: host_bytes = b'' ip_packed = addr4a.host.packed expected = b''.join((b'\4\1', struct.pack('>H', addr4a.port), ip_packed, user_id, b'\0', host_bytes)) assert messages == [expected] def test_short_response(self, addr4a, auth): client = SOCKS4a(addr4a, auth) server = FakeResponder(self.short_bytes()) with pytest.raises(HangingError): run_communication(client, server) def test_request_rejected_89(self, addr4a, auth): client = SOCKS4a(addr4a, auth) server = FakeResponder(self.fail_bytes(89)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'unknown SOCKS4a reply code 89' in str(err.value) def test_request_rejected_91(self, addr4a, auth): client = SOCKS4a(addr4a, auth) server = FakeResponder(self.fail_bytes(91)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'request rejected or failed' in str(err.value) def test_request_rejected_92(self, addr4a, auth): client = SOCKS4a(addr4a, auth) server = FakeResponder(self.fail_bytes(92)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'cannot connect to identd' in str(err.value) def test_request_rejected_93(self, addr4a, auth): client = SOCKS4a(addr4a, auth) server = FakeResponder(self.fail_bytes(93)) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'report different' in str(err.value) def test_response_bad_first_byte(self, addr4a, auth): client = SOCKS4a(addr4a, auth) server = FakeResponder(self.bad_first_byte()) with pytest.raises(SOCKSProtocolError) as err: run_communication(client, server) assert 'invalid SOCKS4a proxy response' in str(err.value) def test_rejects_others(self, addr4a_bad, auth): with pytest.raises(SOCKSProtocolError) as err: SOCKS4a(addr4a_bad, auth) assert 'SOCKS4a requires an IPv4' in str(err.value) class TestSOCKS5(object): @classmethod def host_packed(cls, host, port, client): result = b'\5\1\0' if client else b'\5\0\0' if isinstance(host, ipaddress.IPv4Address): result += bytes([1]) + host.packed elif 
isinstance(host, str): result += bytes([3, len(host)]) + host.encode() else: result += bytes([4]) + host.packed return result + struct.pack('>H', port) @classmethod def response(cls, chosen_auth, host, greeting=None, auth_response=None, conn_response=None): if greeting is None: greeting = bytes([5, chosen_auth]) if auth_response is None: if chosen_auth == 2: auth_response = b'\1\0' else: auth_response = b'' if conn_response is None: conn_response = cls.host_packed(host, randrange(0, 65536), False) return b''.join((greeting, auth_response, conn_response)) def test_good(self, auth, chosen_auth, addr5): if chosen_auth == 2 and auth is None: return client = SOCKS5(addr5, auth) server = FakeResponder(self.response(chosen_auth, addr5.host)) messages = run_communication(client, server) expected = [] if auth is not None: expected.append(bytes([5, 2, 0, 2])) else: expected.append(bytes([5, 1, 0])) if chosen_auth == 2: expected.append(bytes([1, len(auth.username)]) + auth.username.encode() + bytes([len(auth.password)]) + auth.password.encode()) expected.append(self.host_packed(addr5.host, addr5.port, True)) assert messages == expected def test_short_username(self, addr5): auth = SOCKSUserAuth(username='', password='password') with pytest.raises(SOCKSProtocolError) as err: SOCKS5(addr5, auth) assert 'username' in str(err.value) def test_long_username(self, addr5): auth = SOCKSUserAuth(username='u' * 256, password='password') with pytest.raises(SOCKSProtocolError) as err: SOCKS5(addr5, auth) assert 'username' in str(err.value) def test_short_password(self, addr5): auth = SOCKSUserAuth(username='username', password='') with pytest.raises(SOCKSProtocolError) as err: SOCKS5(addr5, auth) assert 'password has invalid length' in str(err.value) def test_long_password(self, addr5): auth = SOCKSUserAuth(username='username', password='p' * 256) with pytest.raises(SOCKSProtocolError) as err: SOCKS5(addr5, auth) assert 'password has invalid length' in str(err.value) def test_auth_failure(self, addr5): auth = auth_methods[1] client = SOCKS5(addr5, auth) auth_failure_bytes = self.response(2, addr5.host, auth_response=b'\1\xff') server = FakeResponder(auth_failure_bytes) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'SOCKS5 proxy auth failure code: 255' in str(err.value) def test_reject_auth_methods(self, addr5): auth = auth_methods[1] client = SOCKS5(addr5, auth) reject_methods_bytes = self.response(2, addr5.host, greeting=b'\5\xff') server = FakeResponder(reject_methods_bytes) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert 'SOCKS5 proxy rejected authentication methods' in str(err.value) def test_bad_proto_version(self, auth, addr5): client = SOCKS5(addr5, auth) bad_proto_bytes = self.response(2, addr5.host, greeting=b'\4\0') server = FakeResponder(bad_proto_bytes) with pytest.raises(SOCKSProtocolError) as err: run_communication(client, server) assert 'invalid SOCKS5 proxy response' in str(err.value) def test_short_greeting(self, auth, addr5): client = SOCKS5(addr5, auth) short_greeting_bytes = self.response(2, addr5.host, greeting=b'\5') server = FakeResponder(short_greeting_bytes) with pytest.raises(SOCKSError): run_communication(client, server) def test_short_auth_reply(self, addr5): auth = auth_methods[1] client = SOCKS5(addr5, auth) short_auth_bytes = self.response(2, addr5.host, auth_response=b'\1') server = FakeResponder(short_auth_bytes) with pytest.raises(SOCKSError): run_communication(client, server) def test_bad_auth_reply(self, 
addr5): auth = auth_methods[1] client = SOCKS5(addr5, auth) bad_auth_bytes = self.response(2, addr5.host, auth_response=b'\0\0') server = FakeResponder(bad_auth_bytes) with pytest.raises(SOCKSProtocolError) as err: run_communication(client, server) assert 'invalid SOCKS5 proxy auth response' in str(err.value) def test_bad_connection_response1(self, auth, chosen_auth, addr5): if chosen_auth == 2 and auth is None: return client = SOCKS5(addr5, auth) response = bytearray(self.host_packed(addr5.host, 50000, False)) response[0] = 4 # Should be 5 response = self.response(chosen_auth, addr5.host, conn_response=response) server = FakeResponder(response) with pytest.raises(SOCKSProtocolError) as err: run_communication(client, server) assert 'invalid SOCKS5 proxy response' in str(err.value) def test_bad_connection_response2(self, auth, chosen_auth, addr5): if chosen_auth == 2 and auth is None: return client = SOCKS5(addr5, auth) response = bytearray(self.host_packed(addr5.host, 50000, False)) response[2] = 1 # Should be 0 response = self.response(chosen_auth, addr5.host, conn_response=response) server = FakeResponder(response) with pytest.raises(SOCKSProtocolError) as err: run_communication(client, server) assert 'invalid SOCKS5 proxy response' in str(err.value) def test_bad_connection_response3(self, auth, chosen_auth, addr5): if chosen_auth == 2 and auth is None: return client = SOCKS5(addr5, auth) response = bytearray(self.host_packed(addr5.host, 50000, False)) response[3] = 2 # Should be 1, 3 or 4 response = self.response(chosen_auth, addr5.host, conn_response=response) server = FakeResponder(response) with pytest.raises(SOCKSProtocolError) as err: run_communication(client, server) assert 'invalid SOCKS5 proxy response' in str(err.value) def check_failure(self, auth, chosen_auth, addr5, code, msg): if chosen_auth == 2 and auth is None: return client = SOCKS5(addr5, auth) response = bytearray(self.host_packed(addr5.host, 50000, False)) # Various error codes response[1] = code response = self.response(chosen_auth, addr5.host, conn_response=response) server = FakeResponder(response) with pytest.raises(SOCKSFailure) as err: run_communication(client, server) assert msg in str(err.value) def test_error_code_1(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 1, 'general SOCKS server failure') def test_error_code_2(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 2, 'connection not allowed by ruleset') def test_error_code_3(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 3, 'network unreachable') def test_error_code_4(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 4, 'host unreachable') def test_error_code_5(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 5, 'connection refused') def test_error_code_6(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 6, 'TTL expired') def test_error_code_7(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 7, 'command not supported') def test_error_code_8(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 8, 'address type not supported') def test_error_code_9(self, auth, chosen_auth, addr5): self.check_failure(auth, chosen_auth, addr5, 9, 'unknown SOCKS5 error code: 9') def test_short_req_reply(self, auth, chosen_auth, addr5): if chosen_auth == 2 and auth is None: return client = SOCKS5(addr5, auth) response = self.response(chosen_auth, 
addr5.host)[:-1] server = FakeResponder(response) with pytest.raises(HangingError): run_communication(client, server) class FakeServer(asyncio.Protocol): response = None def connection_made(self, transport): self.transport = transport def data_received(self, data): self.transport.write(self.response) localhosts = ['127.0.0.1', '::1', 'localhost'] @pytest.fixture(params=localhosts) def proxy_address(request, event_loop, unused_tcp_port): host = request.param coro = event_loop.create_server(FakeServer, host=host, port=unused_tcp_port) server = event_loop.run_until_complete(coro) yield NetAddress(host, unused_tcp_port) server.close() def local_hosts(host): if host == 'localhost': return ['::1', '127.0.0.1'] return [str(host)] class TestSOCKSProxy(object): @pytest.mark.asyncio async def test_good_SOCKS5(self, proxy_address, auth): chosen_auth = 2 if auth else 0 FakeServer.response = TestSOCKS5.response(chosen_auth, 'wwww.apple.com') result = await SOCKSProxy.auto_detect_at_address(proxy_address, auth) assert isinstance(result, SOCKSProxy) assert result.protocol is SOCKS5 assert result.address == proxy_address assert result.auth == auth assert result.peername[0] in local_hosts(proxy_address.host) assert result.peername[1] == proxy_address.port @pytest.mark.asyncio async def test_good_SOCKS4a(self, proxy_address, auth): FakeServer.response = TestSOCKS4a.response() result = await SOCKSProxy.auto_detect_at_address(proxy_address, auth) assert isinstance(result, SOCKSProxy) assert result.protocol is SOCKS4a assert result.address == proxy_address assert result.auth == auth assert result.peername[0] in local_hosts(proxy_address.host) assert result.peername[1] == proxy_address.port @pytest.mark.asyncio async def test_good_SOCKS4(self, proxy_address, auth): FakeServer.response = TestSOCKS4.response() result = await SOCKSProxy.auto_detect_at_address(proxy_address, auth) assert isinstance(result, SOCKSProxy) # FIXME: how to actually distinguish SOCKS4 and SOCKS4a? 
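        # The FIXME above is why SOCKS4a is asserted here: TestSOCKS4.response()
        # and TestSOCKS4a.response() build byte-identical 8-byte replies
        # (version 0, reply code 0x5a, then six bytes of port and address), so a
        # plain SOCKS4 server is indistinguishable on the wire, and
        # auto-detection -- which presumably tries SOCKS4a before SOCKS4 --
        # settles on SOCKS4a.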
assert result.protocol is SOCKS4a assert result.address == proxy_address assert result.auth == auth assert result.peername[0] in local_hosts(proxy_address.host) assert result.peername[1] == proxy_address.port # @pytest.mark.asyncio # async def test_auto_detect_at_address_failure(self): # result = await SOCKSProxy.auto_detect_at_address('8.8.8.8:53', None) # assert result is None @pytest.mark.asyncio async def test_auto_detect_at_address_cannot_connect(self): result = await SOCKSProxy.auto_detect_at_address('localhost:1', None) assert result is None @pytest.mark.asyncio async def test_autodetect_at_host_success(self, proxy_address, auth): chosen_auth = 2 if auth else 0 FakeServer.response = TestSOCKS5.response(chosen_auth, 'wwww.apple.com') result = await SOCKSProxy.auto_detect_at_host(proxy_address.host, [proxy_address.port], auth) assert isinstance(result, SOCKSProxy) assert result.protocol is SOCKS5 assert result.address == proxy_address assert result.auth == auth assert result.peername[0] in local_hosts(proxy_address.host) assert result.peername[1] == proxy_address.port @pytest.mark.asyncio async def test_autodetect_at_host_failure(self, auth): ports = [1, 2] chosen_auth = 2 if auth else 0 FakeServer.response = TestSOCKS5.response(chosen_auth, 'wwww.apple.com') result = await SOCKSProxy.auto_detect_at_host('localhost', ports, auth) assert result is None @pytest.mark.asyncio async def test_create_connection_connect_failure(self, auth): proxy = SOCKSProxy('localhost:1', SOCKS5, auth) with pytest.raises(OSError): await proxy.create_connection(None, GCOM.host, GCOM.port) @pytest.mark.asyncio async def test_create_connection_good(self, proxy_address, auth): chosen_auth = 2 if auth else 0 FakeServer.response = TestSOCKS5.response(chosen_auth, 'wwww.apple.com') proxy = SOCKSProxy(proxy_address, SOCKS5, auth) async with connect_rs(GCOM.host, GCOM.port, proxy) as session: assert session.remote_address() == GCOM assert session.proxy() is proxy assert proxy.peername[0] in local_hosts(proxy_address.host) assert proxy.peername[1] == proxy_address.port assert isinstance(session, RPCSession) @pytest.mark.asyncio async def test_create_connection_resolve_good(self, proxy_address, auth): chosen_auth = 2 if auth else 0 proxy = SOCKSProxy(proxy_address, SOCKS5, auth) FakeServer.response = TestSOCKS5.response(chosen_auth, 'wwww.apple.com') async with connect_rs(GCOM.host, GCOM.port, proxy, resolve=True) as session: assert session.remote_address().host not in (None, GCOM.host) assert session.remote_address().port == GCOM.port assert proxy.peername[0] in local_hosts(proxy_address.host) assert proxy.peername[1] == proxy_address.port assert isinstance(session, RPCSession) @pytest.mark.asyncio async def test_create_connection_resolve_bad(self, proxy_address, auth): protocol_factory = partial(RSTransport, RPCSession, 'client') proxy = SOCKSProxy(proxy_address, SOCKS5, auth) with pytest.raises(OSError): await proxy.create_connection(protocol_factory, 'foobar.onion', 80, resolve=True) def test_str(self): address = NetAddress('localhost', 80) p = SOCKSProxy(address, SOCKS4a, None) assert str(p) == f'SOCKS4a proxy at {address}, auth: none' address = NetAddress('www.google.com', 8080) p = SOCKSProxy(address, SOCKS5, auth_methods[1]) assert str(p) == f'SOCKS5 proxy at {address}, auth: username' def test_random(self): auth1 = auth_methods[1] auth2 = SOCKSRandomAuth() # SOCKSRandomAuth is a SOCKSUserAuth assert isinstance(auth2, SOCKSUserAuth) # Username of SOCKSUserAuth should be constant user1a = auth1.username user1b 
= auth1.username assert user1a == user1b # Password of SOCKSUserAuth should be constant pass1a = auth1.password pass1b = auth1.password assert pass1a == pass1b # Username of SOCKSRandomAuth should be random user2a = auth2.username user2b = auth2.username assert user2a != user2b # Password of SOCKSRandomAuth should be random pass2a = auth2.password pass2b = auth2.password assert pass2a != pass2b def test_basic(): assert issubclass(SOCKSProtocolError, SOCKSError) assert issubclass(SOCKSFailure, SOCKSError) aiorpcX-0.24/tests/test_unixsocket.py000077500000000000000000000027261474217261100200320ustar00rootroot00000000000000import sys import asyncio import pytest import tempfile from os import path from aiorpcx import connect_us, serve_us from test_session import MyServerSession if sys.platform.startswith("win"): pytest.skip("skipping tests not compatible with Windows platform", allow_module_level=True) @pytest.fixture def us_server(event_loop): with tempfile.TemporaryDirectory() as tmp_folder: socket_path = path.join(tmp_folder, 'test.socket') coro = serve_us(MyServerSession, socket_path, loop=event_loop) server = event_loop.run_until_complete(coro) yield socket_path tasks = asyncio.all_tasks(event_loop) async def close_all(): server.close() await server.wait_closed() if tasks: await asyncio.wait(tasks) event_loop.run_until_complete(close_all()) class TestUSTransport: @pytest.mark.asyncio async def test_send_request(self, us_server): async with connect_us(us_server) as session: assert await session.send_request('echo', [23]) == 23 @pytest.mark.asyncio async def test_is_closing(self, us_server): async with connect_us(us_server) as session: assert not session.is_closing() await session.close() assert session.is_closing() async with connect_us(us_server) as session: assert not session.is_closing() await session.abort() assert session.is_closing() aiorpcX-0.24/tests/test_util.py000077500000000000000000000277661474217261100166260ustar00rootroot00000000000000import asyncio from functools import partial from ipaddress import IPv4Address, IPv6Address import pytest from aiorpcx.util import ( is_async_call, is_valid_hostname, validate_port, validate_protocol, classify_host, Service, NetAddress, ServicePart, ) async def coro(x, y): pass def test_is_async_call(): z = coro(2, 3) assert not is_async_call(z) assert is_async_call(coro) assert is_async_call(partial(coro, 3, 4)) assert is_async_call(partial(partial(coro, 3), 4)) assert not is_async_call(test_is_async_call) assert not is_async_call(partial(is_async_call)) # Lose a warning asyncio.get_event_loop().run_until_complete(z) @pytest.mark.parametrize("hostname,answer", ( ('', False), ('a', True), ('_', True), # Hyphens ('-b', False), ('a.-b', False), ('a-b', True), ('b-', False), ('b-.c', False), # Dots ('a.', True), ('a..', False), ('foo1.Foo', True), ('foo1..Foo', False), ('12Foo.Bar.Bax_', True), ('12Foo.Bar.Baz_12', True), # Numeric TLD ('foo1.123', False), ('foo1.d123', True), ('foo1.123d', True), # IP Addresses ('1.2.3.4', False), ('12::23', False), # 63 octets in part ('a.abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_.bar', True), # Over 63 octets in part ('a.abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_1.bar', False), # Length (('a' * 62 + '.') * 4 + 'a', True), # 253 (('a' * 62 + '.') * 4 + 'ab', False), # 254 )) def test_is_valid_hostname(hostname, answer): assert is_valid_hostname(hostname) == answer @pytest.mark.parametrize("hostname", (2, b'1.2.3.4')) def test_is_valid_hostname_bad(hostname): with 
pytest.raises(TypeError): is_valid_hostname(hostname) @pytest.mark.parametrize("host,answer", ( ('1.2.3.4', IPv4Address('1.2.3.4')), ('12:32::', IPv6Address('12:32::')), (IPv4Address('8.8.8.8'), IPv4Address('8.8.8.8')), (IPv6Address('::1'), IPv6Address('::1')), ('foo.bar.baz.', 'foo.bar.baz.'), )) def test_classify_host(host, answer): assert classify_host(host) == answer @pytest.mark.parametrize("host", (2, b'1.2.3.4')) def test_classify_host_bad_type(host): with pytest.raises(TypeError): classify_host(host) @pytest.mark.parametrize("host", ('', 'a..', 'b-', 'a' * 64)) def test_classify_host_bad(host): with pytest.raises(ValueError): classify_host(host) class TestNetAddress: @pytest.mark.parametrize("host,port,answer,host_type", ( ('foo.bar', '23', 'foo.bar:23', str), ('foo.bar', 23, 'foo.bar:23', str), ('foo.bar', 23.0, TypeError, None), ('::1', 15, '[::1]:15', IPv6Address), ('5.6.7.8', '23', '5.6.7.8:23', IPv4Address), ('5.6.7.8.9', '23', ValueError, None), ('[::1]', '23', ValueError, None), ('[::1]', 0, ValueError, None), ('[::1]', 65536, ValueError, None), )) def test_constructor(self, host, port, answer, host_type): if isinstance(answer, type) and issubclass(answer, Exception): with pytest.raises(answer): NetAddress(host, port) else: address = NetAddress(host, port) assert str(address) == answer assert isinstance(address.host, host_type) def test_eq(self): assert NetAddress('1.2.3.4', 23) == NetAddress('1.2.3.4', 23) assert NetAddress('1.2.3.4', 23) == NetAddress('1.2.3.4', '23') assert NetAddress('1.2.3.4', 23) != NetAddress('1.2.3.4', 24) assert NetAddress('1.2.3.4', 24) != NetAddress('1.2.3.5', 24) assert NetAddress('foo.bar', 24) != NetAddress('foo.baz', 24) def test_hashable(self): assert len({NetAddress('1.2.3.4', 23), NetAddress('1.2.3.4', '23')}) == 1 @pytest.mark.parametrize("host,port,answer", ( ('foo.bar', '23', "NetAddress('foo.bar', 23)"), ('foo.bar', 23, "NetAddress('foo.bar', 23)"), ('::1', 15, "NetAddress(IPv6Address('::1'), 15)"), ('5.6.7.8', '23', "NetAddress(IPv4Address('5.6.7.8'), 23)"), )) def test_repr(self, host, port, answer): assert repr(NetAddress(host, port)) == answer @pytest.mark.parametrize("string,default_func,answer", ( ('foo.bar:23', None, NetAddress('foo.bar', 23)), (':23', NetAddress.default_host('localhost'), NetAddress('localhost', 23)), (':23', None, ValueError), (':23', NetAddress.default_port(23), ValueError), ('foo.bar', NetAddress.default_port(500), NetAddress('foo.bar', 500)), ('foo.bar:', NetAddress.default_port(500), NetAddress('foo.bar', 500)), ('foo.bar', NetAddress.default_port(500), NetAddress('foo.bar', 500)), (':', NetAddress.default_host_and_port('localhost', 80), NetAddress('localhost', 80)), ('::1:', None, ValueError), ('::1', None, ValueError), ('[::1:22', None, ValueError), ('[::1]:22', NetAddress.default_port(500), NetAddress('::1', 22)), ('[::1]:', NetAddress.default_port(500), NetAddress('::1', 500)), ('[::1]', NetAddress.default_port(500), NetAddress('::1', 500)), ('1.2.3.4:22', None, NetAddress('1.2.3.4', 22)), ('1.2.3.4:', NetAddress.default_port(500), NetAddress('1.2.3.4', 500)), ('1.2.3.4', NetAddress.default_port(500), NetAddress('1.2.3.4', 500)), ('localhost', NetAddress.default_port(500), NetAddress('localhost', 500)), ('1.2.3.4', NetAddress.default_host('localhost'), ValueError), (2, None, TypeError), (b'', None, TypeError), )) def test_from_string(self, string, default_func, answer): if isinstance(answer, type) and issubclass(answer, Exception): with pytest.raises(answer): NetAddress.from_string(string, 
default_func=default_func) else: assert NetAddress.from_string(string, default_func=default_func) == answer @pytest.mark.parametrize("address,answer", ( (NetAddress('foo.bar', 23), 'foo.bar:23'), (NetAddress('abcd::dbca', 40), '[abcd::dbca]:40'), (NetAddress('1.2.3.5', 50000), '1.2.3.5:50000'), )) def test_str(self, address, answer): assert str(address) == answer @pytest.mark.parametrize("attr", ('host', 'port')) def test_immutable(self, attr): address = NetAddress('foo.bar', 23) with pytest.raises(AttributeError): setattr(address, attr, 'foo') setattr(address, 'foo', '') class TestService: @pytest.mark.parametrize("protocol,address,answer", ( ('tcp', 'domain.tld:8000', Service('tcp', NetAddress('domain.tld', 8000))), ('SSL', NetAddress('domain.tld', '23'), Service('ssl', NetAddress('domain.tld', 23))), ('SSL', '[::1]:80', Service('SSL', NetAddress('::1', 80))), ('ws', '1.2.3.4:80', Service('ws', NetAddress('1.2.3.4', 80))), (4, '1.2.3.4:80', TypeError), ('wss', '1.2.3.4:', ValueError), )) def test_constructor(self, protocol, address, answer): if isinstance(answer, type) and issubclass(answer, Exception): with pytest.raises(answer): Service(protocol, address) else: assert Service(protocol, address) == answer def test_eq(self): assert Service('http', '1.2.3.4:23') == Service( 'HTTP', NetAddress(IPv4Address('1.2.3.4'), 23)) assert Service('https', '1.2.3.4:23') != Service('http', '1.2.3.4:23') assert Service('https', '1.2.3.4:23') != Service('https', '1.2.3.4:22') def test_hashable(self): assert 1 == len({Service('http', '1.2.3.4:23'), Service('HTTP', NetAddress(IPv4Address('1.2.3.4'), 23))}) @pytest.mark.parametrize("protocol,address,answer", ( ('TCP', 'foo.bar:23', 'tcp://foo.bar:23'), ('httpS', NetAddress('::1', 80), 'https://[::1]:80'), ('ws', NetAddress('1.2.3.4', '50000'), 'ws://1.2.3.4:50000'), )) def test_str(self, protocol, address, answer): assert str(Service(protocol, address)) == answer @pytest.mark.parametrize("protocol, address, answer", ( ('TCP', 'foo.bar:23', "Service('tcp', 'foo.bar:23')"), ('httpS', NetAddress('::1', 80), "Service('https', '[::1]:80')"), ('ws', NetAddress('1.2.3.4', '50000'), "Service('ws', '1.2.3.4:50000')"), )) def test_repr(self, protocol, address, answer): assert repr(Service(protocol, address)) == answer def test_attributes(self): service = Service('HttpS', '[::1]:80') assert service.protocol == 'https' assert service.address == NetAddress('::1', 80) assert service.host == IPv6Address('::1') assert service.port == 80 def default_func(protocol, kind): if kind == ServicePart.PROTOCOL: return 'SSL' if kind == ServicePart.HOST: return {'ssl': 'ssl_host.tld', 'tcp': 'tcp_host.tld'}.get(protocol) return {'ssl': 443, 'tcp': '80', 'ws': 50001}.get(protocol) @pytest.mark.parametrize("service,default_func,answer", ( ('HTTP://foo.BAR:80', None, Service('http', NetAddress('foo.BAR', 80))), ('ssl://[::1]:80', None, Service('ssl', '[::1]:80')), ('ssl://5.6.7.8:50001', None, Service('ssl', NetAddress('5.6.7.8', 50001))), ('ssl://foo.bar', None, ValueError), ('ssl://:80', None, ValueError), ('foo.bar:80', None, ValueError), ('foo.bar', None, ValueError), (2, None, TypeError), # With default funcs ('localhost:80', default_func, Service('ssl', 'localhost:80')), ('localhost', default_func, Service('ssl', 'localhost:443')), ('WS://domain.tld', default_func, Service('ws', 'domain.tld:50001')), # TCP has a default host and port ('tcp://localhost', default_func, Service('tcp', 'localhost:80')), ('tcp://:', default_func, Service('tcp', 'tcp_host.tld:80')), ('tcp://', 
default_func, Service('tcp', 'tcp_host.tld:80')), # As TCP has a default host and port it is interpreted as a protocol not a host ('tcp', default_func, Service('tcp', 'tcp_host.tld:80')), # WS has no default host ('ws://', default_func, ValueError), ('ws://:45', default_func, ValueError), ('ws://localhost', default_func, Service('ws', 'localhost:50001')), # WS alone is interpreted as a host name as WS protocol has no default host ('ws', default_func, Service('ssl', 'ws:443')), # Default everything ('', default_func, Service('ssl', 'ssl_host.tld:443')), )) def test_from_string(self, service, default_func, answer): if isinstance(answer, type) and issubclass(answer, Exception): with pytest.raises(answer): Service.from_string(service, default_func=default_func) else: assert Service.from_string(service, default_func=default_func) == answer @pytest.mark.parametrize("attr", ('host', 'port', 'address', 'protocol')) def test_immutable(self, attr): service = Service.from_string('https://foo.bar:8000') with pytest.raises(AttributeError): setattr(service, attr, '') setattr(service, 'foo', '') @pytest.mark.parametrize("port,answer", ( ('2', 2), (65535, 65535), (0, ValueError), (-1, ValueError), (65536, ValueError), (b'', TypeError), (2.0, TypeError), ('2a', ValueError), )) def test_validate_port(port, answer): if isinstance(answer, type) and issubclass(answer, Exception): with pytest.raises(answer): validate_port(port) else: assert validate_port(port) == answer @pytest.mark.parametrize("protocol,answer", ( ('TCP', 'tcp'), ('http', 'http'), ('Ftp.-xbar+', 'ftp.-xbar+'), (b'', TypeError), (2, TypeError), ('', ValueError), ('a@b', ValueError), ('a:b', ValueError), ('[23]', ValueError), )) def test_validate_protocol(protocol, answer): if isinstance(answer, type) and issubclass(answer, Exception): with pytest.raises(answer): validate_protocol(protocol) else: assert validate_protocol(protocol) == answer aiorpcX-0.24/tests/test_websocket.py000077500000000000000000000027111474217261100176160ustar00rootroot00000000000000import pytest from aiorpcx import connect_ws, NetAddress, serve_ws from test_session import MyServerSession @pytest.fixture(scope="function") async def ws_server(unused_tcp_port): server = await serve_ws(MyServerSession, 'localhost', unused_tcp_port) yield f'ws://localhost:{unused_tcp_port}' server.close() await server.wait_closed() @pytest.mark.filterwarnings("ignore:'with .*:DeprecationWarning") class TestWSTransport: @pytest.mark.asyncio async def test_send_request(self, ws_server): async with connect_ws(ws_server) as session: assert await session.send_request('echo', [23]) == 23 @pytest.mark.asyncio async def test_basics(self, ws_server): async with connect_ws(ws_server) as session: assert session.proxy() is None remote_address = session.remote_address() assert isinstance(remote_address, NetAddress) assert str(remote_address.host) in ('localhost', '::1', '127.0.0.1') assert ws_server.endswith(str(remote_address.port)) @pytest.mark.asyncio async def test_is_closing(self, ws_server): async with connect_ws(ws_server) as session: assert not session.is_closing() await session.close() assert session.is_closing() async with connect_ws(ws_server) as session: assert not session.is_closing() await session.abort() assert session.is_closing() aiorpcX-0.24/tests/util.py000077500000000000000000000026411474217261100155500ustar00rootroot00000000000000from aiorpcx import ProtocolError, RPCError def assert_ProtocolError(exception, code, message): assert isinstance(exception, ProtocolError), \ f'expected 
ProtocolError got {exception.__class__.__name__}' assert exception.code == code, \ f'expected {code} got {exception.code}' if message: assert message in exception.message, \ f'{message} not in {exception.message}' def assert_RPCError(exception, code, message): assert isinstance(exception, RPCError), \ f'expected RPCError got {exception.__class__.__name__}' assert exception.code == code, \ f'expected {code} got {exception.code}' if message: assert message in exception.message, \ f'{message} not in {exception.message}' class RaiseTest(object): def __init__(self, code, message, exc_type): self.code = code self.message = message self.exc_type = exc_type def __enter__(self): return self def __exit__(self, exc_type, exc_value, traceback): assert exc_type is self.exc_type, \ f'expected {self.exc_type} got {exc_type}' assert exc_value.code == self.code, \ f'expected {self.code} got {exc_value.code}' if self.message: assert self.message in exc_value.message, \ f'{self.message} not in {exc_value.message}' self.value = exc_value return True
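# A minimal usage sketch of RaiseTest, mirroring test_send_request_bad_args in
# test_session.py: the with-block must raise exactly `exc_type`, whose .code
# equals `code` and whose .message contains `message`:
#
#     with RaiseTest(JSONRPC.INVALID_ARGS, 'list', ProtocolError):
#         await session.send_request('echo', "23")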