From 6abb31e46919d68e0b393a52a09269ae06e55a97 Mon Sep 17 00:00:00 2001 From: Mark Qvist Date: Sun, 26 Apr 2026 22:24:00 +0200 Subject: [PATCH] Added rnsh to included utilities --- RNS/Utilities/rnsh/__init__.py | 29 ++ RNS/Utilities/rnsh/_version.py | 1 + RNS/Utilities/rnsh/exception.py | 26 ++ RNS/Utilities/rnsh/helpers.py | 25 ++ RNS/Utilities/rnsh/initiator.py | 474 ++++++++++++++++++++ RNS/Utilities/rnsh/listener.py | 219 +++++++++ RNS/Utilities/rnsh/loop.py | 12 + RNS/Utilities/rnsh/process.py | 773 ++++++++++++++++++++++++++++++++ RNS/Utilities/rnsh/protocol.py | 115 +++++ RNS/Utilities/rnsh/retry.py | 189 ++++++++ RNS/Utilities/rnsh/rnsh.py | 161 +++++++ RNS/Utilities/rnsh/session.py | 407 +++++++++++++++++ 12 files changed, 2431 insertions(+) create mode 100644 RNS/Utilities/rnsh/__init__.py create mode 100644 RNS/Utilities/rnsh/_version.py create mode 100644 RNS/Utilities/rnsh/exception.py create mode 100644 RNS/Utilities/rnsh/helpers.py create mode 100644 RNS/Utilities/rnsh/initiator.py create mode 100644 RNS/Utilities/rnsh/listener.py create mode 100644 RNS/Utilities/rnsh/loop.py create mode 100644 RNS/Utilities/rnsh/process.py create mode 100644 RNS/Utilities/rnsh/protocol.py create mode 100644 RNS/Utilities/rnsh/retry.py create mode 100644 RNS/Utilities/rnsh/rnsh.py create mode 100644 RNS/Utilities/rnsh/session.py diff --git a/RNS/Utilities/rnsh/__init__.py b/RNS/Utilities/rnsh/__init__.py new file mode 100644 index 0000000..cc07f45 --- /dev/null +++ b/RNS/Utilities/rnsh/__init__.py @@ -0,0 +1,29 @@ +# MIT License +# +# Copyright (c) 2023 Aaron Heise +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from ._version import __version__ + +import os +module_abs_filename = os.path.abspath(__file__) +module_dir = os.path.dirname(module_abs_filename) + +def _get_version(): return __version__ diff --git a/RNS/Utilities/rnsh/_version.py b/RNS/Utilities/rnsh/_version.py new file mode 100644 index 0000000..49f34f4 --- /dev/null +++ b/RNS/Utilities/rnsh/_version.py @@ -0,0 +1 @@ +__version__ = "0.2.0" \ No newline at end of file diff --git a/RNS/Utilities/rnsh/exception.py b/RNS/Utilities/rnsh/exception.py new file mode 100644 index 0000000..0a2a178 --- /dev/null +++ b/RNS/Utilities/rnsh/exception.py @@ -0,0 +1,26 @@ +import contextlib +from contextlib import AbstractContextManager +import logging +import sys + + +class permit(AbstractContextManager): + """Context manager to allow specified exceptions + + The specified exceptions will be allowed to bubble up. Other + exceptions are suppressed. + + After a non-matching exception is suppressed, execution proceeds + with the next statement following the with statement. + + with allow(KeyboardInterrupt): + time.sleep(300) + # Execution still resumes here if no KeyboardInterrupt + """ + + def __init__(self, *exceptions): self._exceptions = exceptions + + def __enter__(self): pass + + def __exit__(self, exctype, excinst, exctb): + return exctype is not None and not issubclass(exctype, self._exceptions) diff --git a/RNS/Utilities/rnsh/helpers.py b/RNS/Utilities/rnsh/helpers.py new file mode 100644 index 0000000..35033b8 --- /dev/null +++ b/RNS/Utilities/rnsh/helpers.py @@ -0,0 +1,25 @@ +import asyncio +import time + +def bitwise_or_if(value: int, condition: bool, orval: int): + if not condition: return value + return value | orval + +def check_and(value: int, andval: int) -> bool: + return (value & andval) > 0 + +class SleepRate: + def __init__(self, target_period: float): + self.target_period = target_period + self.last_wake = time.time() + + def next_sleep_time(self) -> float: + old_last_wake = self.last_wake + self.last_wake = time.time() + next_wake = max(old_last_wake + 0.01, self.last_wake) + sleep_for = next_wake - self.last_wake + return sleep_for if sleep_for > 0 else 0 + + async def sleep_async(self): await asyncio.sleep(self.next_sleep_time()) + + def sleep_block(self): time.sleep(self.next_sleep_time()) diff --git a/RNS/Utilities/rnsh/initiator.py b/RNS/Utilities/rnsh/initiator.py new file mode 100644 index 0000000..96df290 --- /dev/null +++ b/RNS/Utilities/rnsh/initiator.py @@ -0,0 +1,474 @@ +#!/usr/bin/env python3 + +# MIT License +# +# Copyright (c) 2016-2022 Mark Qvist / unsigned.io +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import asyncio +import base64 +import enum +import functools +import os +import queue +import shlex +import signal +import sys +import termios +import threading +import time +import tty +from typing import Callable, TypeVar +import RNS +import rnsh.exception as exception +import rnsh.process as process +import rnsh.retry as retry +import rnsh.session as session +import re +import contextlib +import rnsh.args +import pwd +import bz2 +import rnsh.protocol as protocol +import rnsh.helpers as helpers +import rnsh.rnsh + +_identity = None +_reticulum = None +_cmd: [str] | None = None +DATA_AVAIL_MSG = "data available" +_finished: asyncio.Event = None +_retry_timer: retry.RetryThread | None = None +_destination: RNS.Destination | None = None +_loop: asyncio.AbstractEventLoop | None = None + + +async def _check_finished(timeout: float = 0): + return _finished is not None and await process.event_wait(_finished, timeout=timeout) + +def _sigint_handler(sig, loop): + global _finished + RNS.log(f"{signal.Signals(sig).name}", RNS.LOG_DEBUG) + if _finished is not None: _finished.set() + else: raise KeyboardInterrupt() + +async def _spin_tty(until=None, msg=None, timeout=None): + i = 0 + syms = "⢄⢂⢁⡁⡈⡐⡠" + if timeout != None: timeout = time.time()+timeout + + print(msg+" ", end=" ") + while (timeout == None or time.time() timeout: return False + else: return True + + +async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = None) -> bool: + if timeout is not None: timeout += time.time() + + while (timeout is None or time.time() < timeout) and not until(): + if await _check_finished(0.1): raise asyncio.CancelledError() + + if timeout is not None and time.time() > timeout: return False + else: return True + +async def _spin(until: callable = None, msg=None, timeout: float | None = None, quiet: bool = False) -> bool: + if not quiet and os.isatty(1): return await _spin_tty(until, msg, timeout) + else: return await _spin_pipe(until, msg, timeout) + +_link: RNS.Link | None = None +_remote_exec_grace = 2.0 +_pq = queue.Queue() + + +class InitiatorState(enum.IntEnum): + IS_INITIAL = 0 + IS_LINKED = 1 + IS_WAIT_VERS = 2 + IS_RUNNING = 3 + IS_TERMINATE = 4 + IS_TEARDOWN = 5 + +def _client_link_closed(link): + if _finished: _finished.set() + +def _client_message_handler(message: RNS.MessageBase): _pq.put(message) + +def compute_target_rns_loglevel(verbosity: int, quietness: int, base_level: int = RNS.LOG_INFO) -> int: + try: + target = int(base_level) + int(verbosity) - int(quietness) + if target < RNS.LOG_CRITICAL: target = RNS.LOG_CRITICAL + if target > RNS.LOG_DEBUG: target = RNS.LOG_DEBUG + return target + + except Exception: return base_level + + +class RemoteExecutionError(Exception): + def __init__(self, msg): self.msg = msg + + +async def _initiate_link(configdir, rnsconfigdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None, + timeout=RNS.Transport.PATH_REQUEST_TIMEOUT): + global _identity, _reticulum, _link, _destination, _remote_exec_grace + + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 + if len(destination) != dest_len: + raise RemoteExecutionError( + "Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format( + hex=dest_len, byte=dest_len // 2)) + try: + destination_hash = bytes.fromhex(destination) + except Exception as e: + raise RemoteExecutionError("Invalid destination entered. Check your input.") + + if _reticulum is None: + targetloglevel = compute_target_rns_loglevel(verbosity, quietness, RNS.LOG_ERROR) + RNS.logfile = os.path.join(configdir, "logfile") + _reticulum = RNS.Reticulum(configdir=rnsconfigdir, loglevel=targetloglevel, logdest=RNS.LOG_FILE) + + if _identity is None: + _identity = rnsh.rnsh.prepare_identity(identitypath) + + if not RNS.Transport.has_path(destination_hash): + RNS.Transport.request_path(destination_hash) + RNS.log(f"Requesting path...", RNS.LOG_INFO) + if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...", + timeout=timeout, quiet=quietness > 0): + raise RemoteExecutionError("Path not found") + + if _destination is None: + listener_identity = RNS.Identity.recall(destination_hash) + _destination = RNS.Destination( + listener_identity, + RNS.Destination.OUT, + RNS.Destination.SINGLE, + rnsh.rnsh.APP_NAME + ) + + if _link is None or _link.status == RNS.Link.PENDING: + RNS.log("No link", RNS.LOG_DEBUG) + _link = RNS.Link(_destination) + _link.did_identify = False + + _link.set_link_closed_callback(_client_link_closed) + + RNS.log(f"Establishing link...", RNS.LOG_VERBOSE) + if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...", + timeout=timeout, quiet=quietness > 0): + raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash)) + + RNS.log("Have link", RNS.LOG_DEBUG) + if not noid and not _link.did_identify: + # Delay a tiny bit to allow listener to fully enter WAIT_IDENT state + await asyncio.sleep(min(1, _link.rtt * 1.1 + 0.05)) + _link.identify(_identity) + _link.did_identify = True + + +async def _handle_error(errmsg: RNS.MessageBase): + if isinstance(errmsg, protocol.ErrorMessage): + with contextlib.suppress(Exception): + if _link and _link.status == RNS.Link.ACTIVE: + _link.teardown() + await asyncio.sleep(0.1) + raise RemoteExecutionError(f"Remote error: {errmsg.msg}") + + +async def initiate(configdir: str, rnsconfigdir:str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str, + timeout: float, command: [str] | None = None): + global _finished, _link + with process.TTYRestorer(sys.stdin.fileno()) as ttyRestorer: + loop = asyncio.get_running_loop() + state = InitiatorState.IS_INITIAL + data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray() + line_buffer = bytearray() + + await _initiate_link(configdir=configdir, + rnsconfigdir=rnsconfigdir, + identitypath=identitypath, + verbosity=verbosity, + quietness=quietness, + noid=noid, + destination=destination, + timeout=timeout) + + if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]: + return 255 + + state = InitiatorState.IS_LINKED + outlet = session.RNSOutlet(_link) + channel = _link.get_channel() + protocol.register_message_types(channel) + channel.add_message_handler(_client_message_handler) + + # Next step after linking and identifying: send version + # if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5, quiet=quietness > 0): + # print("Error bringing up link") + # return 253 + + channel.send(protocol.VersionInfoMessage()) + try: + vm = _pq.get(timeout=max(outlet.rtt * 20, 5)) + await _handle_error(vm) + if not isinstance(vm, protocol.VersionInfoMessage): + raise Exception("Invalid message received") + RNS.log(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}", RNS.LOG_DEBUG) + state = InitiatorState.IS_RUNNING + except queue.Empty: + print("Protocol error") + return 254 + + winch = False + def sigwinch_handler(): + nonlocal winch + winch = True + + esc = False + pre_esc = True + line_mode = False + line_flush = False + blind_write_count = 0 + flush_chars = ["\x01", "\x03", "\x04", "\x05", "\x0c", "\x11", "\x13", "\x15", "\x19", "\t", "\x1A", "\x1B"] + def handle_escape(b): + nonlocal line_mode + if b == "?": + os.write(1, "\n\r\n\rSupported rnsh escape sequences:".encode("utf-8")) + os.write(1, "\n\r ~~ Send the escape character by typing it twice".encode("utf-8")) + os.write(1, "\n\r ~. Terminate session and exit immediately".encode("utf-8")) + os.write(1, "\n\r ~L Toggle line-interactive mode".encode("utf-8")) + os.write(1, "\n\r ~? Display this quick reference\n\r".encode("utf-8")) + os.write(1, "\n\r(Escape sequences are only recognized immediately after newline)\n\r".encode("utf-8")) + return None + elif b == ".": + _link.teardown() + return None + elif b == "L": + line_mode = not line_mode + if line_mode: + os.write(1, "\n\rLine-interactive mode enabled\n\r".encode("utf-8")) + else: + os.write(1, "\n\rLine-interactive mode disabled\n\r".encode("utf-8")) + return None + + return b + + stdin_eof = False + def stdin(): + nonlocal stdin_eof, pre_esc, esc, line_mode + nonlocal line_flush, blind_write_count + try: + in_data = process.tty_read(sys.stdin.fileno()) + if in_data is not None: + data = bytearray() + for b in bytes(in_data): + c = chr(b) + if c == "\r": + pre_esc = True + line_flush = True + data.append(b) + elif line_mode and c in flush_chars: + pre_esc = False + line_flush = True + data.append(b) + elif line_mode and (c == "\b" or c == "\x7f"): + pre_esc = False + if len(line_buffer)>0: + line_buffer.pop(-1) + blind_write_count -= 1 + os.write(1, "\b \b".encode("utf-8")) + elif pre_esc == True and c == "~": + pre_esc = False + esc = True + elif esc == True: + ret = handle_escape(c) + if ret != None: + if ret != "~": + data.append(ord("~")) + data.append(ord(ret)) + esc = False + else: + pre_esc = False + data.append(b) + + if not line_mode: + data_buffer.extend(data) + else: + line_buffer.extend(data) + if line_flush: + data_buffer.extend(line_buffer) + line_buffer.clear() + os.write(1, ("\b \b"*blind_write_count).encode("utf-8")) + line_flush = False + blind_write_count = 0 + else: + os.write(1, data) + blind_write_count += len(data) + + except EOFError: + if os.isatty(0): + data_buffer.extend(process.CTRL_D) + stdin_eof = True + process.tty_unset_reader_callbacks(sys.stdin.fileno()) + + process.tty_add_reader_callback(sys.stdin.fileno(), stdin) + + tcattr = None + rows, cols, hpix, vpix = (None, None, None, None) + try: + tcattr = termios.tcgetattr(0) + rows, cols, hpix, vpix = process.tty_get_winsize(0) + except: + try: + tcattr = termios.tcgetattr(1) + rows, cols, hpix, vpix = process.tty_get_winsize(1) + except: + try: + tcattr = termios.tcgetattr(2) + rows, cols, hpix, vpix = process.tty_get_winsize(2) + except: + pass + + await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0) + channel.send(protocol.ExecuteCommandMesssage(cmdline=command, + pipe_stdin=not os.isatty(0), + pipe_stdout=not os.isatty(1), + pipe_stderr=not os.isatty(2), + tcflags=tcattr, + term=os.environ.get("TERM", None), + rows=rows, + cols=cols, + hpix=hpix, + vpix=vpix)) + + loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler) + _finished = asyncio.Event() + loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop)) + loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop)) + mdu = _link.MDU - 16 + sent_eof = False + last_winch = time.time() + sleeper = helpers.SleepRate(0.01) + processed = False + while not await _check_finished() and state in [InitiatorState.IS_RUNNING]: + try: + try: + message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005) + await _handle_error(message) + processed = True + if isinstance(message, protocol.StreamDataMessage): + if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT: + if message.data and len(message.data) > 0: + ttyRestorer.raw() + RNS.log(f"stdout: {message.data}", RNS.LOG_DEBUG) + os.write(1, message.data) + sys.stdout.flush() + if message.eof: + os.close(1) + if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR: + if message.data and len(message.data) > 0: + ttyRestorer.raw() + RNS.log(f"stdout: {message.data}", RNS.LOG_DEBUG) + os.write(2, message.data) + sys.stderr.flush() + if message.eof: + os.close(2) + elif isinstance(message, protocol.CommandExitedMessage): + RNS.log(f"received return code {message.return_code}, exiting", RNS.LOG_DEBUG) + return message.return_code + elif isinstance(message, protocol.ErrorMessage): + RNS.log(f"Remote error: {message.data}", RNS.LOG_ERROR) + if message.fatal: + _link.teardown() + return 200 + + except queue.Empty: + processed = False + + if channel.is_ready_to_send(): + def compress_adaptive(buf: bytes): + comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES + comp_try = 1 + comp_success = False + + chunk_len = len(buf) + if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN: + chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN + chunk_segment = None + + chunk_segment = None + max_data_len = channel.mdu - protocol.StreamDataMessage.OVERHEAD + while chunk_len > 32 and comp_try < comp_tries: + chunk_segment_length = int(chunk_len/comp_try) + compressed_chunk = bz2.compress(buf[:chunk_segment_length]) + compressed_length = len(compressed_chunk) + if compressed_length < max_data_len and compressed_length < chunk_segment_length: + comp_success = True + break + else: + comp_try += 1 + + if comp_success: + diff = max_data_len - len(compressed_chunk) + chunk = compressed_chunk + processed_length = chunk_segment_length + else: + chunk = bytes(buf[:max_data_len]) + processed_length = len(chunk) + + return comp_success, processed_length, chunk + + comp_success, processed_length, chunk = compress_adaptive(data_buffer) + stdin = chunk + data_buffer = data_buffer[processed_length:] + eof = not sent_eof and stdin_eof and len(stdin) == 0 + if len(stdin) > 0 or eof: + channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof, comp_success)) + sent_eof = eof + processed = True + + # send window change, but rate limited + if winch and time.time() - last_winch > _link.rtt * 25: + last_winch = time.time() + winch = False + with contextlib.suppress(Exception): + r, c, h, v = process.tty_get_winsize(0) + channel.send(protocol.WindowSizeMessage(r, c, h, v)) + processed = True + except RemoteExecutionError as e: + print(e.msg) + return 255 + except Exception as ex: + print(f"Client exception: {ex}") + if _link and _link.status != RNS.Link.CLOSED: + _link.teardown() + return 127 + + RNS.log("Main loop done", RNS.LOG_DEBUG) + return 0 \ No newline at end of file diff --git a/RNS/Utilities/rnsh/listener.py b/RNS/Utilities/rnsh/listener.py new file mode 100644 index 0000000..fa89741 --- /dev/null +++ b/RNS/Utilities/rnsh/listener.py @@ -0,0 +1,219 @@ +#!/usr/bin/env python3 + +# MIT License +# +# Copyright (c) 2016-2022 Mark Qvist / unsigned.io +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import asyncio +import os +import queue +import shlex +import signal +import sys +import termios +import threading +import time +import tty +from typing import Callable, TypeVar +import RNS +import rnsh.exception as exception +import rnsh.process as process +import rnsh.retry as retry +import rnsh.session as session +import re +import contextlib +import rnsh.args +import pwd +import rnsh.protocol as protocol +import rnsh.helpers as helpers +import rnsh.rnsh + + +_identity = None +_reticulum = None +_allow_all = False +_allowed_file = None +_allowed_identity_hashes = [] +_allowed_file_identity_hashes = [] +_cmd: [str] | None = None +DATA_AVAIL_MSG = "data available" +_finished: asyncio.Event = None +_retry_timer: retry.RetryThread | None = None +_destination: RNS.Destination | None = None +_loop: asyncio.AbstractEventLoop | None = None +_no_remote_command = True +_remote_cmd_as_args = False + + +async def _check_finished(timeout: float = 0): + return await process.event_wait(_finished, timeout=timeout) + + +def _sigint_handler(sig, loop): + global _finished + RNS.log(f"Signal: {signal.Signals(sig).name}", RNS.LOG_DEBUG) + if _finished is not None: _finished.set() + else: raise KeyboardInterrupt() + +def _reload_allowed_file(): + global _allowed_file, _allowed_file_identity_hashes + if _allowed_file != None: + try: + with open(_allowed_file, "r") as file: + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 + added = 0 + line = 0 + _allowed_file_identity_hashes = [] + for allow in file.read().replace("\r", "").split("\n"): + line += 1 + if len(allow) == dest_len: + try: + destination_hash = bytes.fromhex(allow) + _allowed_file_identity_hashes.append(destination_hash) + added += 1 + except Exception: + RNS.log(f"Discarded invalid Identity hash in {_allowed_file} at line {line}", RNS.LOG_DEBUG) + + ms = "y" if added == 1 else "ies" + RNS.log(f"Loaded {added} allowed identit{ms} from "+str(_allowed_file), RNS.LOG_DEBUG) + + except Exception as e: RNS.log(f"Error while reloading allowed indetities file: {e}", RNS.LOG_ERROR) + +def compute_target_rns_loglevel(verbosity: int, quietness: int, base_level: int = RNS.LOG_INFO) -> int: + try: + target = int(base_level) + int(verbosity) - int(quietness) + if target < RNS.LOG_CRITICAL: target = RNS.LOG_CRITICAL + if target > RNS.LOG_DEBUG: target = RNS.LOG_DEBUG + return target + + except Exception: return base_level + +async def listen(configdir, rnsconfigdir, command, identitypath=None, service_name=None, verbosity=0, quietness=0, allowed=None, + allowed_file=None, disable_auth=None, announce_period=900, no_remote_command=True, remote_cmd_as_args=False, + loop: asyncio.AbstractEventLoop = None): + global _identity, _allow_all, _allowed_identity_hashes, _allowed_file, _allowed_file_identity_hashes + global _reticulum, _cmd, _destination, _no_remote_command, _remote_cmd_as_args, _finished + + if not loop: loop = asyncio.get_running_loop() + if service_name is None or len(service_name) == 0: + service_name = "default" + + RNS.log(f"Using service name {service_name}", RNS.LOG_INFO) + + # More -v should increase verbosity (higher RNS.loglevel); -q should decrease it + targetloglevel = compute_target_rns_loglevel(verbosity, quietness, RNS.LOG_INFO) + _reticulum = RNS.Reticulum(configdir=rnsconfigdir, loglevel=targetloglevel) + _identity = rnsh.rnsh.prepare_identity(identitypath, service_name) + _destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, rnsh.rnsh.APP_NAME) + + RNS.log(f"rnsh listening for commands on {RNS.prettyhexrep(_destination.hash)}", RNS.LOG_NOTICE) + + _cmd = command + if _cmd is None or len(_cmd) == 0: + shell = None + try: shell = pwd.getpwuid(os.getuid()).pw_shell + except Exception as e: RNS.log(f"Error looking up shell: {e}", RNS.LOG_ERROR) + RNS.log(f"Using {shell} for default command.", RNS.LOG_INFO) + + # Ensure a sane shell default. Fall back to /bin/sh if lookup fails. + if not shell or len(shell) == 0: shell = "/bin/sh" + _cmd = [shell] + + else: RNS.log(f"Using command {shlex.join(_cmd)}", RNS.LOG_INFO) + + _no_remote_command = no_remote_command + session.ListenerSession.allow_remote_command = not no_remote_command + _remote_cmd_as_args = remote_cmd_as_args + if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \ + and (_no_remote_command or _remote_cmd_as_args): + raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no .") + + session.ListenerSession.default_command = _cmd + session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args + + if disable_auth: + _allow_all = True + session.ListenerSession.allow_all = True + else: + if allowed_file is not None: + _allowed_file = allowed_file + _reload_allowed_file() + + if allowed is not None: + for a in allowed: + try: + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2 + if len(a) != dest_len: + raise ValueError( + "Allowed destination length is invalid, must be {hex} hexadecimal " + + "characters ({byte} bytes).".format( + hex=dest_len, byte=dest_len // 2)) + try: + destination_hash = bytes.fromhex(a) + _allowed_identity_hashes.append(destination_hash) + session.ListenerSession.allowed_identity_hashes.append(destination_hash) + except Exception: + raise ValueError("Invalid destination entered. Check your input.") + + except Exception as e: + RNS.log(f"Unhandled error: {e}", RNS.LOG_ERROR) + RNS.trace_exception(e) + exit(1) + + if (len(_allowed_identity_hashes) < 1 and len(_allowed_file_identity_hashes) < 1) and not disable_auth: + RNS.log("Warning: No allowed identities configured, rnsh will not accept any connections!", RNS.LOG_WARNING) + + def link_established(lnk: RNS.Link): + _reload_allowed_file() + session.ListenerSession.allowed_file_identity_hashes = _allowed_file_identity_hashes + session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop) + _destination.set_link_established_callback(link_established) + + _finished = asyncio.Event() + signal.signal(signal.SIGINT, _sigint_handler) + + if announce_period is not None: _destination.announce() + + last_announce = time.time() + sleeper = helpers.SleepRate(0.01) + + try: + while not await _check_finished(): + if announce_period and 0 < announce_period < time.time() - last_announce: + last_announce = time.time() + _destination.announce() + if len(session.ListenerSession.sessions) > 0: + # no sleep if there's work to do + if not await session.ListenerSession.pump_all(): + await sleeper.sleep_async() + else: + await asyncio.sleep(0.25) + finally: + RNS.log("Shutting down", RNS.LOG_NOTICE) + await session.ListenerSession.terminate_all("Shutting down") + await asyncio.sleep(1) + links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links)) + for link in links_still_active: + if link.status not in [RNS.Link.CLOSED]: + link.teardown() + await asyncio.sleep(0.01) \ No newline at end of file diff --git a/RNS/Utilities/rnsh/loop.py b/RNS/Utilities/rnsh/loop.py new file mode 100644 index 0000000..19f83ba --- /dev/null +++ b/RNS/Utilities/rnsh/loop.py @@ -0,0 +1,12 @@ +import asyncio +import functools +from typing import Callable + +def sig_handler_sys_to_loop(handler: Callable[[int, any], None]) -> Callable[[int, asyncio.AbstractEventLoop], None]: + def wrapped(cb: Callable[[int, any], None], signal: int, loop: asyncio.AbstractEventLoop): cb(signal, None) + return functools.partial(wrapped, handler) + +def loop_set_signal(sig, handler: Callable[[int, asyncio.AbstractEventLoop], None], loop: asyncio.AbstractEventLoop = None): + if loop is None: loop = asyncio.get_running_loop() + loop.remove_signal_handler(sig) + loop.add_signal_handler(sig, functools.partial(handler, sig, loop)) \ No newline at end of file diff --git a/RNS/Utilities/rnsh/process.py b/RNS/Utilities/rnsh/process.py new file mode 100644 index 0000000..66e9976 --- /dev/null +++ b/RNS/Utilities/rnsh/process.py @@ -0,0 +1,773 @@ +# MIT License +# +# Copyright (c) 2023 Aaron Heise +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations +import asyncio +import contextlib +import copy +import errno +import fcntl +import functools +import os +import pty +import select +import signal +import struct +import sys +import termios +import threading +import tty +import types +import typing +import RNS + +import rnsh.exception as exception + +CTRL_C = "\x03".encode("utf-8") +CTRL_D = "\x04".encode("utf-8") + +def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop = None): + """ + Add an async reader callback for a tty file descriptor. + + Example usage: + + def reader(): + data = tty_read(fd) + # do something with data + + tty_add_reader_callback(self._child_fd, reader, self._loop) + + :param fd: file descriptor + :param callback: callback function + :param loop: asyncio event loop to which the reader should be added. If None, use the currently-running loop. + """ + if loop is None: + loop = asyncio.get_running_loop() + loop.add_reader(fd, callback) + + +def tty_read(fd: int) -> bytes: + """ + Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using + tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys. + :param fd: tty file descriptor + :return: bytes read + """ + if fd_is_closed(fd): + raise EOFError + + try: + run = True + result = bytearray() + while not fd_is_closed(fd): + ready, _, _ = select.select([fd], [], [], 0) + if len(ready) == 0: + break + for f in ready: + try: + data = os.read(f, 4096) + except OSError as e: + if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK: + raise + else: + if not data: # EOF + if data is not None and len(data) > 0: + result.extend(data) + return result + elif len(result) > 0: + return result + else: + raise EOFError + if data is not None and len(data) > 0: + result.extend(data) + return result + + except EOFError: raise + except Exception as e: RNS.log(f"TTY read error: {e}", RNS.LOG_ERROR) + + +def tty_read_poll(fd: int) -> bytes: + """ + Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using + tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys. + :param fd: tty file descriptor + :return: bytes read + """ + if fd_is_closed(fd): + raise EOFError + + result = bytearray() + try: + flags = fcntl.fcntl(fd, fcntl.F_GETFL) + fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK) + while True: + try: + data = os.read(fd, 4096) + if not data: + # EOF + if len(result) > 0: + return result + raise EOFError + result.extend(data) + # continue loop to drain + except OSError as e: + if e.errno in (errno.EWOULDBLOCK, errno.EAGAIN): + break + if e.errno == errno.EIO: + if len(result) > 0: + return result + raise EOFError + raise + except EOFError: raise + except Exception as e: RNS.log(f"TTY read error: {e}", RNS.LOG_ERROR) + + return result + + +def fd_is_closed(fd: int) -> bool: + """ + Check if file descriptor is closed + :param fd: file descriptor + :return: True if file descriptor is closed + """ + try: + fcntl.fcntl(fd, fcntl.F_GETFL) < 0 + except OSError as ose: + return ose.errno == errno.EBADF + + +def tty_unset_reader_callbacks(fd: int, loop: asyncio.AbstractEventLoop = None): + """ + Remove async reader callbacks for file descriptor. + :param fd: file descriptor + :param loop: asyncio event loop from which to remove callbacks + """ + with exception.permit(SystemExit): + if loop is None: + loop = asyncio.get_running_loop() + loop.remove_reader(fd) + + +def tty_get_winsize(fd: int) -> [int, int, int, int]: + """ + Ge the window size of a tty. + :param fd: file descriptor of tty + :return: (rows, cols, h_pixels, v_pixels) + """ + packed = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0)) + rows, cols, h_pixels, v_pixels = struct.unpack('HHHH', packed) + return rows, cols, h_pixels, v_pixels + + +def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int): + """ + Set the window size on a tty. + :param fd: file descriptor of tty + :param rows: number of visible rows + :param cols: number of visible columns + :param h_pixels: number of visible horizontal pixels + :param v_pixels: number of visible vertical pixels + """ + if fd < 0: + return + packed = struct.pack('HHHH', rows, cols, h_pixels, v_pixels) + fcntl.ioctl(fd, termios.TIOCSWINSZ, packed) + + +def process_exists(pid) -> bool: + """ + Check For the existence of a unix pid. + :param pid: process id to check + :return: True if process exists + """ + try: + os.kill(pid, 0) + except OSError: + return False + else: + return True + + +class TTYRestorer(contextlib.AbstractContextManager): + # Indexes of flags within the attrs array + ATTR_IDX_IFLAG = 0 + ATTR_IDX_OFLAG = 1 + ATTR_IDX_CFLAG = 2 + ATTR_IDX_LFLAG = 4 + ATTR_IDX_CC = 5 + + def __init__(self, fd: int, suppress_logs=False): + """ + Saves termios attributes for a tty for later restoration. + + The attributes are an array of values with the following meanings. + + tcflag_t c_iflag; /* input modes */ + tcflag_t c_oflag; /* output modes */ + tcflag_t c_cflag; /* control modes */ + tcflag_t c_lflag; /* local modes */ + cc_t c_cc[NCCS]; /* special characters */ + + :param fd: file descriptor of tty + """ + self._fd = fd + self._tattr = None + self._suppress_logs = suppress_logs + self._tattr = self.current_attr() + if not self._tattr and not self._suppress_logs: RNS.log(f"Could not get attrs for fd {fd}", RNS.LOG_DEBUG) + + def raw(self): + """ + Set raw mode on tty + """ + if self._fd is None: + return + with contextlib.suppress(termios.error): + tty.setraw(self._fd, termios.TCSANOW) + + def original_attr(self) -> [any]: + return copy.deepcopy(self._tattr) + + def current_attr(self) -> [any]: + """ + Get the current termios attributes for the wrapped fd. + :return: attribute array + """ + if self._fd is None: + return None + + with contextlib.suppress(termios.error): + return copy.deepcopy(termios.tcgetattr(self._fd)) + return None + + def set_attr(self, attr: [any], when: int = termios.TCSADRAIN): + """ + Set termios attributes + :param attr: attribute list to set + :param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH) + """ + if not attr or self._fd is None: + return + + with contextlib.suppress(termios.error): + termios.tcsetattr(self._fd, when, attr) + + def isatty(self): + return os.isatty(self._fd) if self._fd is not None else None + + def restore(self): + """ + Restore termios settings to state captured in constructor. + """ + self.set_attr(self._tattr, termios.TCSADRAIN) + + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: + self.restore() + return False #__exc_type is not None and issubclass(__exc_type, termios.error) + + +def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None): + if not loop: + loop = asyncio.get_running_loop() + + #TODO: this is hacky + async def wait(): + while not evt.is_set(): + await asyncio.sleep(0.1) + return True + + return loop.create_task(wait()) + + +class AggregateException(Exception): + def __init__(self, inner_exceptions: [Exception]): + super().__init__() + self.inner_exceptions = inner_exceptions + + def __str__(self): + return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(lambda e: str(e), self.inner_exceptions)) + + +async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> (any, any): + tasks = list(map(lambda evt: (evt, _task_from_event(evt)), evts)) + try: + finished, unfinished = await asyncio.wait(map(lambda t: t[1], tasks), + timeout=timeout, + return_when=asyncio.FIRST_COMPLETED) + + if len(unfinished) > 0: + for task in unfinished: + task.cancel() + await asyncio.wait(unfinished) + + exceptions = [] + + for f in finished: + ex = f.exception() + if ex and not isinstance(ex, asyncio.CancelledError) and not isinstance(ex, TimeoutError): + exceptions.append(ex) + + if len(exceptions) > 0: + raise AggregateException(exceptions) + + return next(map(lambda t: next(map(lambda tt: tt[0], tasks)), finished), None) + finally: + unfinished = [] + for task in map(lambda t: t[1], tasks): + if task.done(): + if not task.cancelled(): + task.exception() + else: + task.cancel() + unfinished.append(task) + if len(unfinished) > 0: + await asyncio.wait(unfinished) + + +async def event_wait(evt: asyncio.Event, timeout: float) -> bool: + """ + Wait for event to be set, or timeout to expire. + :param evt: asyncio.Event to wait on + :param timeout: maximum number of seconds to wait. + :return: True if event was set, False if timeout expired + """ + await event_wait_any([evt], timeout=timeout) + return evt.is_set() + + +def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool, stdout_is_pipe: bool, + stderr_is_pipe: bool) -> tuple[int, int, int, int]: + # Set up PTY and/or pipes + child_fd = parent_fd = None + if not (stdin_is_pipe and stdout_is_pipe and stderr_is_pipe): + parent_fd, child_fd = pty.openpty() + child_stdin, parent_stdin = (os.pipe() if stdin_is_pipe else (child_fd, parent_fd)) + parent_stdout, child_stdout = (os.pipe() if stdout_is_pipe else (parent_fd, child_fd)) + parent_stderr, child_stderr = (os.pipe() if stderr_is_pipe else (parent_fd, child_fd)) + + # Fork + pid = os.fork() + + if pid == 0: + try: + # We are in the child process, so close all open sockets and pipes except for the PTY and/or pipes + max_fd = os.sysconf("SC_OPEN_MAX") + for fd in range(3, max_fd): + if fd not in (child_stdin, child_stdout, child_stderr): + try: + os.close(fd) + except OSError: + pass + + # Set up PTY and/or pipes + os.dup2(child_stdin, 0) + os.dup2(child_stdout, 1) + os.dup2(child_stderr, 2) + # Make PTY controlling if necessary so that CTRL_C/CTRL_D behave as expected + if child_fd is not None: + os.setsid() + try: + tty_fd = 0 if not stdin_is_pipe else (1 if not stdout_is_pipe else 2) + # Set controlling TTY for this session + fcntl.ioctl(tty_fd, termios.TIOCSCTTY, 0) + except Exception: + pass + # Ensure the child is the foreground process group for the TTY + try: + os.setpgid(0, 0) + pgid = os.getpgrp() + import struct as _struct + fcntl.ioctl(tty_fd, termios.TIOCSPGRP, _struct.pack('i', pgid)) + except Exception: + pass + # Ensure canonical input with signals and local echo enabled + try: + tty_fd = 0 if not stdin_is_pipe else (1 if not stdout_is_pipe else 2) + attrs = termios.tcgetattr(tty_fd) + lflag = attrs[3] + lflag |= termios.ICANON | termios.ISIG | termios.ECHO + attrs[3] = lflag + termios.tcsetattr(tty_fd, termios.TCSANOW, attrs) + except Exception: + pass + + # Execute the command + os.execvpe(cmd_line[0], cmd_line, env) + except Exception as err: + exc_type, exc_obj, exc_tb = sys.exc_info() + fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1] + print(f"Unable to start {cmd_line[0]}: {err} ({fname}:{exc_tb.tb_lineno})") + sys.stdout.flush() + # don't let any other modules get in our way, do an immediate silent exit. + os._exit(255) + + else: + # We are in the parent process, so close the child-side of the PTY and/or pipes + if child_fd is not None: + os.close(child_fd) + if child_stdin != child_fd: + os.close(child_stdin) + if child_stdout != child_fd: + os.close(child_stdout) + if child_stderr != child_fd: + os.close(child_stderr) + # # Close the write end of the pipe if a pipe is used for standard input + # if not stdin_is_pipe: + # os.close(parent_stdin) + # Return the child PID and the file descriptors for the PTY and/or pipes + return pid, parent_stdin, parent_stdout, parent_stderr + + +class CallbackSubprocess: + # time between checks of child process + PROCESS_POLL_TIME: float = 0.1 + # Close pipes soon after process exit to avoid scheduling on closed event loops + PROCESS_PIPE_TIME: int = 1 + + def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable, + stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool, + stderr_is_pipe: bool): + """ + Fork a child process and generate callbacks with output from the process. + :param argv: the command line, tokenized. The first element must be the absolute path to an executable file. + :param env: environment variables to override + :param loop: the asyncio event loop to use + :param stdout_callback: callback for data, e.g. def callback(data:bytes) -> None + :param terminated_callback: callback for termination/return code, e.g. def callback(return_code:int) -> None + """ + assert loop is not None, "loop should not be None" + assert stdout_callback is not None, "stdout_callback should not be None" + assert terminated_callback is not None, "terminated_callback should not be None" + + self._command: [str] = argv + self._env = env or {} + self._loop = loop + self._stdout_cb = stdout_callback + self._stderr_cb = stderr_callback + self._terminated_cb = terminated_callback + self._pid: int = None + self._child_stdin: int = None + self._child_stdout: int = None + self._child_stderr: int = None + self._return_code: int = None + self._stdout_eof: bool = False + self._stderr_eof: bool = False + self._stdin_is_pipe = stdin_is_pipe + self._stdout_is_pipe = stdout_is_pipe + self._stderr_is_pipe = stderr_is_pipe + self._at_line_start: bool = True + self._tty_line_buffer: bytearray = bytearray() + + def _ensure_pipes_closed(self): + stdin = self._child_stdin + stdout = self._child_stdout + stderr = self._child_stderr + fds = set(filter(lambda x: x is not None, list({stdin, stdout, stderr}))) + RNS.log(f"Queuing close of pipes for ended process (fds: {fds})", RNS.LOG_DEBUG) + + def ensure_pipes_closed_inner(): + RNS.log(f"Ensuring pipes are closed (fds: {fds})", RNS.LOG_DEBUG) + for fd in fds: + RNS.log(f"Closing fd {fd}", RNS.LOG_DEBUG) + with contextlib.suppress(OSError): tty_unset_reader_callbacks(fd) + with contextlib.suppress(OSError): os.close(fd) + + self._child_stdin = None + self._child_stdout = None + self._child_stderr = None + + # Avoid scheduling on a closed loop + if self._loop.is_closed(): ensure_pipes_closed_inner() + else: self._loop.call_later(CallbackSubprocess.PROCESS_PIPE_TIME, ensure_pipes_closed_inner) + + def terminate(self, kill_delay: float = 1.0): + """ + Terminate child process if running + :param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL + """ + + RNS.log("terminate()", RNS.LOG_EXTREME) + if not self.running: return + + with exception.permit(SystemExit): os.kill(self._pid, signal.SIGTERM) + + def kill(): + if process_exists(self._pid): + RNS.log("kill()", RNS.LOG_EXTREME) + with exception.permit(SystemExit): + os.kill(self._pid, signal.SIGHUP) + os.kill(self._pid, signal.SIGKILL) + + self._loop.call_later(kill_delay, kill) + + def wait(): + RNS.log("wait()", RNS.LOG_EXTREME) + with contextlib.suppress(OSError): os.waitpid(self._pid, 0) + self._ensure_pipes_closed() + RNS.log("wait() finish", RNS.LOG_EXTREME) + + threading.Thread(target=wait, daemon=True).start() + + def close_stdin(self): + with contextlib.suppress(Exception): + os.close(self._child_stdin) + # Encourage prompt shutdown if child lingers after stdin close + def _ensure_terminate(): + if self.running: + self.terminate(kill_delay=0.2) + if not self._loop.is_closed(): + self._loop.call_later(0.05, _ensure_terminate) + + @property + def started(self) -> bool: + """ + :return: True if child process has been started + """ + return self._pid is not None + + @property + def running(self) -> bool: + """ + :return: True if child process is still running + """ + return self._pid is not None and process_exists(self._pid) + + def write(self, data: bytes): + """ + Write bytes to the stdin of the child process. + :param data: bytes to write + """ + + os.write(self._child_stdin, data) + + # TODO: Check what this is actually supposed to solve. + # + # For pipe-in + TTY-out, echo should be visible immediately + if self._stdin_is_pipe and not self._stdout_is_pipe and self._stdout_cb is not None and data not in (CTRL_C, CTRL_D): + try: self._stdout_cb(data) + except Exception: pass + + def set_winsize(self, r: int, c: int, h: int, v: int): + """ + Set the window size on the tty of the child process. + :param r: rows visible + :param c: columns visible + :param h: horizontal pixels visible + :param v: vertical pixels visible + :return: + """ + RNS.log(f"set_winsize({r},{c},{h},{v}", RNS.LOG_DEBUG) + tty_set_winsize(self._child_stdout, r, c, h, v) + + def copy_winsize(self, fromfd: int): + """ + Copy window size from one tty to another. + :param fromfd: source tty file descriptor + """ + r, c, h, v = tty_get_winsize(fromfd) + self.set_winsize(r, c, h, v) + + def tcsetattr(self, when: int, attr: list[any]): # actual type is list[int | list[int | bytes]] + """ + Set tty attributes. + :param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH + :param attr: attributes to set + """ + termios.tcsetattr(self._child_stdin, when, attr) + + def tcgetattr(self) -> list[any]: # actual type is list[int | list[int | bytes]] + """ + Get tty attributes. + :return: tty attributes value + """ + return termios.tcgetattr(self._child_stdout) + + def ttysetraw(self): + tty.setraw(self._child_stdout, termios.TCSADRAIN) + + def start(self): + """ + Start the child process. + """ + RNS.log("start()", RNS.LOG_EXTREME) + + # # Using the parent environment seems to do some weird stuff, at least on macOS + # parentenv = os.environ.copy() + # env = {"HOME": parentenv["HOME"], + # "PATH": parentenv["PATH"], + # "TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"), + # "LANG": parentenv.get("LANG"), + # "SHELL": self._command[0]} + + env = os.environ.copy() + for key in self._env: + env[key] = self._env[key] + + program = self._command[0] + assert isinstance(program, str) + + # match = re.search("^/bin/(.*sh)$", program) + # if match: + # self._command[0] = "-" + match.group(1) + # env["SHELL"] = program + # self._log.debug(f"set login shell {self._command}") + + self._pid, \ + self._child_stdin, \ + self._child_stdout, \ + self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe, + self._stderr_is_pipe) + RNS.log(f"Started pid {self.pid}, fds: {self._child_stdin}, {self._child_stdout}, {self._child_stderr}", RNS.LOG_DEBUG) + + def poll(): + try: + pid, self._return_code = os.waitpid(self._pid, os.WNOHANG) + if self._return_code is not None: + self._return_code = self._return_code & 0xff + if self._return_code is not None and not process_exists(self._pid): + RNS.log(f"polled return code {self._return_code}", RNS.LOG_DEBUG) + self._terminated_cb(self._return_code) + if self.running: + self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) + else: + self._ensure_pipes_closed() + except Exception as e: + if not hasattr(e, "errno") or e.errno != errno.ECHILD: + RNS.log(f"Error in process poll: {e}", RNS.LOG_DEBUG) + + self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll) + + def stdout(): + try: + with exception.permit(SystemExit): + data = tty_read_poll(self._child_stdout) + if data is not None and len(data) > 0: + self._stdout_cb(data) + # Opportunistically drain shortly after to coalesce immediate follow-up output + if not self._loop.is_closed(): + self._loop.call_later(0.01, stdout) + except EOFError: + self._stdout_eof = True + tty_unset_reader_callbacks(self._child_stdout) + self._stdout_cb(bytearray()) + + def stderr(): + try: + with exception.permit(SystemExit): + data = tty_read_poll(self._child_stderr) + if data is not None and len(data) > 0: + self._stderr_cb(data) + if not self._loop.is_closed(): + self._loop.call_later(0.01, stderr) + except EOFError: + self._stderr_eof = True + tty_unset_reader_callbacks(self._child_stderr) + self._stderr_cb(bytearray()) + + tty_add_reader_callback(self._child_stdout, stdout, self._loop) + if self._child_stderr != self._child_stdout: + tty_add_reader_callback(self._child_stderr, stderr, self._loop) + + @property + def stdout_eof(self): + return self._stdout_eof or not self.running + + @property + def stderr_eof(self): + return self._stderr_eof or not self.running + + + @property + def return_code(self) -> int: + return self._return_code + + @property + def pid(self) -> int: + return self._pid + + +async def main(): + """ + A test driver for the CallbackProcess class. + python ./process.py /bin/zsh --login + """ + + if len(sys.argv) <= 1: + print(f"Usage: {sys.argv} [child_arg ...]") + exit(1) + + loop = asyncio.get_event_loop() + # asyncio.set_event_loop(loop) + retcode = loop.create_future() + + def stdout(data: bytes): os.write(sys.stdout.fileno(), data) + + def terminated(rc: int): retcode.set_result(rc) + + process = CallbackSubprocess(argv=sys.argv[1:], + env={"TERM": os.environ.get("TERM", "xterm")}, + loop=loop, + stdout_callback=stdout, + terminated_callback=terminated) + + def sigint_handler(sig, frame): + if process is None or process.started and not process.running: + raise KeyboardInterrupt + elif process.running: + process.write("\x03".encode("utf-8")) + + def sigwinch_handler(sig, frame): + process.copy_winsize(sys.stdin.fileno()) + + signal.signal(signal.SIGINT, sigint_handler) + signal.signal(signal.SIGWINCH, sigwinch_handler) + + def stdin(): + try: + data = tty_read(sys.stdin.fileno()) + if data is not None: + process.write(data) + + except EOFError: + tty_unset_reader_callbacks(sys.stdin.fileno()) + process.write(CTRL_D) + + tty_add_reader_callback(sys.stdin.fileno(), stdin) + process.start() + # call_soon called it too soon, not sure why. + loop.call_later(0.001, functools.partial(process.copy_winsize, sys.stdin.fileno())) + + val = await retcode + RNS.log(f"Got return code {val}", RNS.LOG_DEBUG) + return val + + +if __name__ == "__main__": + tr = TTYRestorer(sys.stdin.fileno()) + try: + tr.raw() + asyncio.run(main()) + finally: + tty_unset_reader_callbacks(sys.stdin.fileno()) + tr.restore() diff --git a/RNS/Utilities/rnsh/protocol.py b/RNS/Utilities/rnsh/protocol.py new file mode 100644 index 0000000..664a833 --- /dev/null +++ b/RNS/Utilities/rnsh/protocol.py @@ -0,0 +1,115 @@ +from __future__ import annotations + +import RNS +from RNS.vendor import umsgpack +from RNS.Buffer import StreamDataMessage as RNSStreamDataMessage +import rnsh.retry +import abc +import contextlib +import struct +from abc import ABC, abstractmethod + +MSG_MAGIC = 0xac +PROTOCOL_VERSION = 1 + +def _make_MSGTYPE(val: int): + return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff) + + +class NoopMessage(RNS.MessageBase): + MSGTYPE = _make_MSGTYPE(0) + def pack(self) -> bytes: return bytes() + def unpack(self, raw): pass + + +class WindowSizeMessage(RNS.MessageBase): + MSGTYPE = _make_MSGTYPE(2) + + def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None): + super().__init__() + self.rows = rows + self.cols = cols + self.hpix = hpix + self.vpix = vpix + + def pack(self) -> bytes: return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix)) + def unpack(self, raw): self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) + + +class ExecuteCommandMesssage(RNS.MessageBase): + MSGTYPE = _make_MSGTYPE(3) + + def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False, + pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None, + cols: int = None, hpix: int = None, vpix: int = None): + + super().__init__() + self.cmdline = cmdline + self.pipe_stdin = pipe_stdin + self.pipe_stdout = pipe_stdout + self.pipe_stderr = pipe_stderr + self.tcflags = tcflags + self.term = term + self.rows = rows + self.cols = cols + self.hpix = hpix + self.vpix = vpix + + def pack(self) -> bytes: + return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, + self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix)) + + def unpack(self, raw): + self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \ + self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw) + + +# Create a version of RNS.Buffer.StreamDataMessage that we control +class StreamDataMessage(RNSStreamDataMessage): + MSGTYPE = _make_MSGTYPE(4) + STREAM_ID_STDIN = 0 + STREAM_ID_STDOUT = 1 + STREAM_ID_STDERR = 2 + + +class VersionInfoMessage(RNS.MessageBase): + MSGTYPE = _make_MSGTYPE(5) + + def __init__(self, sw_version: str = None): + super().__init__() + self.sw_version = sw_version or rnsh.__version__ + self.protocol_version = PROTOCOL_VERSION + + def pack(self) -> bytes: return umsgpack.packb((self.sw_version, self.protocol_version)) + def unpack(self, raw): self.sw_version, self.protocol_version = umsgpack.unpackb(raw) + + +class ErrorMessage(RNS.MessageBase): + MSGTYPE = _make_MSGTYPE(6) + + def __init__(self, msg: str = None, fatal: bool = False, data: dict = None): + super().__init__() + self.msg = msg + self.fatal = fatal + self.data = data + + def pack(self) -> bytes: return umsgpack.packb((self.msg, self.fatal, self.data)) + def unpack(self, raw: bytes): self.msg, self.fatal, self.data = umsgpack.unpackb(raw) + + +class CommandExitedMessage(RNS.MessageBase): + MSGTYPE = _make_MSGTYPE(7) + + def __init__(self, return_code: int = None): + super().__init__() + self.return_code = return_code + + def pack(self) -> bytes: return umsgpack.packb(self.return_code) + def unpack(self, raw: bytes): self.return_code = umsgpack.unpackb(raw) + + +message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage, + CommandExitedMessage, ErrorMessage] + +def register_message_types(channel: RNS.Channel.Channel): + for message_type in message_types: channel.register_message_type(message_type) \ No newline at end of file diff --git a/RNS/Utilities/rnsh/retry.py b/RNS/Utilities/rnsh/retry.py new file mode 100644 index 0000000..b761a10 --- /dev/null +++ b/RNS/Utilities/rnsh/retry.py @@ -0,0 +1,189 @@ +# MIT License +# +# Copyright (c) 2023 Aaron Heise +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +import asyncio +import threading +import time +import rnsh.exception as exception +from typing import Callable +from contextlib import AbstractContextManager +import types +import typing + + +class RetryStatus: + def __init__(self, tag: any, try_limit: int, wait_delay: float, retry_callback: Callable[[any, int], any], + timeout_callback: Callable[[any, int], None], tries: int = 1): + + self.tag = tag + self.try_limit = try_limit + self.tries = tries + self.wait_delay = wait_delay + self.retry_callback = retry_callback + self.timeout_callback = timeout_callback + self.try_time = time.time() + self.completed = False + + @property + def ready(self): + ready = time.time() > self.try_time + self.wait_delay + RNS.log(f"ready check {self.tag} try_time {self.try_time} wait_delay {self.wait_delay} " + + f"next_try {self.try_time + self.wait_delay} now {time.time()} " + + f"exceeded {time.time() - self.try_time - self.wait_delay} ready {ready}", RNS.LOG_DEBUG) + return ready + + @property + def timed_out(self): + return self.ready and self.tries >= self.try_limit + + def timeout(self): + self.completed = True + self.timeout_callback(self.tag, self.tries) + + def retry(self) -> any: + self.tries = self.tries + 1 + self.try_time = time.time() + return self.retry_callback(self.tag, self.tries) + + +class RetryThread(AbstractContextManager): + def __init__(self, loop_period: float = 0.25, name: str = "retry thread"): + self._loop_period = loop_period + self._statuses: list[RetryStatus] = [] + self._tag_counter = 0 + self._lock = threading.RLock() + self._run = True + self._finished: asyncio.Future = None + self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True) + self._thread.start() + + def is_alive(self): + return self._thread.is_alive() + + def close(self, loop: asyncio.AbstractEventLoop = None) -> asyncio.Future: + RNS.log("Stopping timer thread", RNS.LOG_DEBUG) + if loop is None: + self._run = False + self._thread.join() + return None + else: + self._finished = loop.create_future() + return self._finished + + def wait(self, timeout: float = None): + if timeout: + timeout = timeout + time.time() + + while timeout is None or time.time() < timeout: + with self._lock: + task_count = len(self._statuses) + if task_count == 0: + return + time.sleep(0.1) + + + def _thread_run(self): + while self._run and self._finished is None: + time.sleep(self._loop_period) + ready: list[RetryStatus] = [] + prune: list[RetryStatus] = [] + with self._lock: ready.extend(list(filter(lambda s: s.ready, self._statuses))) + + for retry in ready: + try: + if not retry.completed: + if retry.timed_out: + RNS.log(f"Timed out {retry.tag} after {retry.try_limit} tries", RNS.LOG_DEBUG) + retry.timeout() + prune.append(retry) + elif retry.ready: + RNS.log(f"Retrying {retry.tag}, try {retry.tries + 1}/{retry.try_limit}", RNS.LOG_DEBUG) + should_continue = retry.retry() + if not should_continue: self.complete(retry.tag) + + except Exception as e: + RNS.log(f"Error processing retry id {retry.tag}: {e}", RNS.LOG_ERROR) + prune.append(retry) + + with self._lock: + for retry in prune: + RNS.log(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}", RNS.LOG_DEBUG) + with exception.permit(SystemExit): self._statuses.remove(retry) + + if self._finished is not None: self._finished.set_result(None) + + def _get_next_tag(self): + self._tag_counter += 1 + return self._tag_counter + + def has_tag(self, tag: any) -> bool: + with self._lock: return next(filter(lambda s: s.tag == tag, self._statuses), None) is not None + + def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any], + timeout_callback: Callable[[any, int], None]) -> any: + + RNS.log(f"Running first try", RNS.LOG_DEBUG) + tag = try_callback(None, 1) + RNS.log(f"First try got id {tag}", RNS.LOG_DEBUG) + + if not tag: + RNS.log(f"Callback returned None/False/0, considering complete.", RNS.LOG_DEBUG) + return None + + with self._lock: + if tag is None: tag = self._get_next_tag() + self.complete(tag) + + self._statuses.append(RetryStatus(tag=tag, + tries=1, + try_limit=try_limit, + wait_delay=wait_delay, + retry_callback=try_callback, + timeout_callback=timeout_callback)) + + RNS.log(f"Added retry timer for {tag}", RNS.LOG_DEBUG) + return tag + + def complete(self, tag: any): + assert tag is not None + with self._lock: + status = next(filter(lambda l: l.tag == tag, self._statuses), None) + if status is not None: + status.completed = True + self._statuses.remove(status) + RNS.log(f"completed {tag}", RNS.LOG_DEBUG) + return + + RNS.log(f"status not found to complete {tag}", RNS.LOG_DEBUG) + + def complete_all(self): + with self._lock: + for status in self._statuses: + status.completed = True + RNS.log(f"completed {status.tag}", RNS.LOG_DEBUG) + + self._statuses.clear() + + def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException, + __traceback: types.TracebackType) -> bool: + self.close() + return False diff --git a/RNS/Utilities/rnsh/rnsh.py b/RNS/Utilities/rnsh/rnsh.py new file mode 100644 index 0000000..f3a3f79 --- /dev/null +++ b/RNS/Utilities/rnsh/rnsh.py @@ -0,0 +1,161 @@ +#!/usr/bin/env python3 + +# MIT License +# +# Copyright (c) 2016-2022 Mark Qvist / unsigned.io +# +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: +# +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. +# +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +from __future__ import annotations + +import asyncio +import base64 + +import re +import os +import sys + +import RNS +import rnsh.process as process +import rnsh.session as session +import rnsh.args +import rnsh.loop +import rnsh.listener as listener +import rnsh.initiator as initiator + +APP_NAME = "rnsh" +loop: asyncio.AbstractEventLoop | None = None + +def _sanitize_service_name(service_name:str) -> str: return re.sub(r'\W+', '', service_name) + +def prepare_identity(identity_path, service_name: str = None) -> tuple[RNS.Identity]: + service_name = _sanitize_service_name(service_name or "") + if identity_path is None: + identity_path = RNS.Reticulum.identitypath + "/" + APP_NAME + \ + (f".{service_name}" if service_name and len(service_name) > 0 else "") + + identity = None + if os.path.isfile(identity_path): + identity = RNS.Identity.from_file(identity_path) + + if identity is None: + RNS.log("No valid saved identity found, creating new...", RNS.LOG_INFO) + identity = RNS.Identity() + identity.to_file(identity_path) + + return identity + + +def print_identity(configdir, identitypath, service_name, include_destination: bool): + reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO) + if service_name and len(service_name) > 0: + print(f"Using service name \"{service_name}\"") + identity = prepare_identity(identitypath, service_name) + destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME) + print("Identity : " + str(identity)) + if include_destination: + print("Listening on : " + RNS.prettyhexrep(destination.hash)) + + exit(0) + +verbose_set = False + +def ensure_config_directory(): + if os.path.isdir(os.path.expanduser("~/.config/rnsh")): return os.path.expanduser("~/.config/rnsh") + elif os.path.isdir(os.path.expanduser("~/.rnsh")): return os.path.expanduser("~/.rnsh") + else: + try: + os.makedirs(os.path.expanduser("~/.rnsh")) + return os.path.expanduser("~/.rnsh") + + except Exception as e: + RNS.log(f"Could not get or create rnsh configuration directory, aborting", RNS.LOG_CRITICAL) + os._exit(1) + + +async def _rnsh_cli_main(): + global verbose_set + args = rnsh.args.Args(sys.argv) + verbose_set = args.verbose > 0 + + configdir = ensure_config_directory() + + if args.print_identity: + print_identity(args.config, args.identity, args.service_name, args.listen) + return 0 + + if args.listen: + allowed_file = None + dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2 + if os.path.isfile(os.path.expanduser("~/.config/rnsh/allowed_identities")): + allowed_file = os.path.expanduser("~/.config/rnsh/allowed_identities") + elif os.path.isfile(os.path.expanduser("~/.rnsh/allowed_identities")): + allowed_file = os.path.expanduser("~/.rnsh/allowed_identities") + + await listener.listen(configdir=configdir, + rnsconfigdir=args.config, + command=args.command_line, + identitypath=args.identity, + service_name=args.service_name, + verbosity=args.verbose, + quietness=args.quiet, + allowed=args.allowed, + allowed_file=allowed_file, + disable_auth=args.no_auth, + announce_period=args.announce, + no_remote_command=args.no_remote_cmd, + remote_cmd_as_args=args.remote_cmd_as_args) + return 0 + + if args.destination is not None: + return_code = await initiator.initiate(configdir=configdir, + rnsconfigdir=args.config, + identitypath=args.identity, + verbosity=args.verbose, + quietness=args.quiet, + noid=args.no_id, + destination=args.destination, + timeout=args.timeout, + command=args.command_line + ) + return return_code if args.mirror else 0 + else: + print("") + print(rnsh.args.usage) + print("") + return 1 + + +def main(): + global verbose_set + return_code = 1 + exc = None + try: return_code = asyncio.run(_rnsh_cli_main()) + except SystemExit: pass + except KeyboardInterrupt: pass + except Exception as ex: + print(f"Unhandled exception: {ex}") + exc = ex + + process.tty_unset_reader_callbacks(0) + if verbose_set and exc: raise exc + sys.exit(return_code if return_code is not None else 255) + + +if __name__ == "__main__": main() diff --git a/RNS/Utilities/rnsh/session.py b/RNS/Utilities/rnsh/session.py new file mode 100644 index 0000000..0c494a6 --- /dev/null +++ b/RNS/Utilities/rnsh/session.py @@ -0,0 +1,407 @@ +from __future__ import annotations +import contextlib +import functools +import rnsh.exception as exception +import asyncio +import rnsh.process as process +import rnsh.helpers as helpers +import rnsh.protocol as protocol +import enum +from typing import TypeVar, Generic, Callable, List +from abc import abstractmethod, ABC +from multiprocessing import Manager +import os +import bz2 +import RNS + +_TLink = TypeVar("_TLink") +_TIdentity = TypeVar("_TIdentity") + +class SEType(enum.IntEnum): + SE_LINK_CLOSED = 0 + +class SessionException(Exception): + def __init__(self, setype: SEType, msg: str, *args): + super().__init__(msg, args) + self.type = setype + +class LSState(enum.IntEnum): + LSSTATE_WAIT_IDENT = 1 + LSSTATE_WAIT_VERS = 2 + LSSTATE_WAIT_CMD = 3 + LSSTATE_RUNNING = 4 + LSSTATE_ERROR = 5 + LSSTATE_TEARDOWN = 6 + + +class LSOutletBase(ABC): + @abstractmethod + def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]): raise NotImplemented() + + @abstractmethod + def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]): raise NotImplemented() + + @abstractmethod + def unset_link_closed_callback(self): raise NotImplemented() + + @property + @abstractmethod + def rtt(self): raise NotImplemented() + + @abstractmethod + def teardown(self): raise NotImplemented() + + +class ListenerSession: + sessions: List[ListenerSession] = [] + allowed_identity_hashes: [any] = [] + allowed_file_identity_hashes: [any] = [] + allow_all: bool = False + allow_remote_command: bool = False + default_command: [str] = [] + remote_cmd_as_args = False + + def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop): + RNS.log(f"Session started for {outlet}", RNS.LOG_INFO) + self.outlet = outlet + self.channel = channel + self.outlet.set_initiator_identified_callback(self._initiator_identified) + self.outlet.set_link_closed_callback(self._link_closed) + self.loop = loop + self.state: LSState = None + self.remote_identity = None + self.term: str | None = None + self.stdin_is_pipe: bool = False + self.stdout_is_pipe: bool = False + self.stderr_is_pipe: bool = False + self.tcflags: [any] = None + self.cmdline: [str] = None + self.rows: int = 0 + self.cols: int = 0 + self.hpix: int = 0 + self.vpix: int = 0 + self.stdout_buf = bytearray() + self.stdout_eof_sent = False + self.stderr_buf = bytearray() + self.stderr_eof_sent = False + self.return_code: int | None = None + self.return_code_sent = False + self.process: process.CallbackSubprocess | None = None + + if self.allow_all: self._set_state(LSState.LSSTATE_WAIT_VERS) + else: self._set_state(LSState.LSSTATE_WAIT_IDENT) + + self.sessions.append(self) + protocol.register_message_types(self.channel) + self.channel.add_message_handler(self._handle_message) + + def _terminated(self, return_code: int): + self.return_code = return_code + + def _set_state(self, state: LSState, timeout_factor: float = 10.0): + timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None + RNS.log(f"Set state: {state.name}, timeout {timeout}", RNS.LOG_DEBUG) + orig_state = self.state + self.state = state + if timeout_factor is not None: + self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout) + + def _call(self, func: callable, delay: float = 0): + def call_inner(): + if delay == 0: func() + else: self.loop.call_later(delay, func) + + self.loop.call_soon_threadsafe(call_inner) + + def send(self, message: RNS.MessageBase): + self.channel.send(message) + + def _protocol_error(self, name: str): + self.terminate(f"Protocol error ({name})") + + def _protocol_timeout_error(self, name: str): + self.terminate(f"Protocol timeout error: {name}") + + def terminate(self, error: str = None): + with contextlib.suppress(Exception): + RNS.log("Terminating session" + (f": {error}" if error else ""), RNS.LOG_DEBUG) + if error and self.state != LSState.LSSTATE_TEARDOWN: + with contextlib.suppress(Exception): + self.send(protocol.ErrorMessage(error, True)) + + self.state = LSState.LSSTATE_ERROR + self._terminate_process() + self._call(self._prune, max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5)) + + def _prune(self): + self.state = LSState.LSSTATE_TEARDOWN + RNS.log("Pruning session", RNS.LOG_DEBUG) + with contextlib.suppress(ValueError): + self.sessions.remove(self) + with contextlib.suppress(Exception): + self.outlet.teardown() + + def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str): + timeout = True + try: timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition() + except Exception as e: RNS.log(f"Error in protocol timeout: {e}", RNS.LOG_ERROR) + if timeout: self._protocol_timeout_error(name) + + def _link_closed(self, outlet: LSOutletBase): + outlet.unset_link_closed_callback() + + if outlet != self.outlet: + RNS.log("Link closed received from incorrect outlet", RNS.LOG_DEBUG) + return + + RNS.log(f"link_closed {outlet}", RNS.LOG_DEBUG) + self.terminate() + + def _initiator_identified(self, outlet, identity): + if outlet != self.outlet: + RNS.log("Identity received from incorrect outlet", RNS.LOG_DEBUG) + return + + RNS.log(f"initiator_identified {identity} on link {outlet}", RNS.LOG_INFO) + if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]: + self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name) + + if not self.allow_all and identity.hash not in self.allowed_identity_hashes and identity.hash not in self.allowed_file_identity_hashes: + self.terminate("Identity is not allowed.") + + self.remote_identity = identity + self._set_state(LSState.LSSTATE_WAIT_VERS) + + @classmethod + async def pump_all(cls) -> True: + processed_any = False + for session in cls.sessions: + processed = session.pump() + processed_any = processed_any or processed + await asyncio.sleep(0) + + + @classmethod + async def terminate_all(cls, reason: str): + for session in cls.sessions: + session.terminate(reason) + await asyncio.sleep(0) + + def pump(self) -> bool: + def compress_adaptive(buf: bytes): + comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES + comp_try = 1 + comp_success = False + + chunk_len = len(buf) + if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN: + chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN + chunk_segment = None + + chunk_segment = None + max_data_len = self.channel.mdu - protocol.StreamDataMessage.OVERHEAD + while chunk_len > 32 and comp_try < comp_tries: + chunk_segment_length = int(chunk_len/comp_try) + compressed_chunk = bz2.compress(buf[:chunk_segment_length]) + compressed_length = len(compressed_chunk) + if compressed_length < max_data_len and compressed_length < chunk_segment_length: + comp_success = True + break + else: + comp_try += 1 + + if comp_success: + diff = max_data_len - len(compressed_chunk) + chunk = compressed_chunk + processed_length = chunk_segment_length + else: + chunk = bytes(buf[:max_data_len]) + processed_length = len(chunk) + + return comp_success, processed_length, chunk + + try: + if self.state != LSState.LSSTATE_RUNNING: + return False + elif not self.channel.is_ready_to_send(): + return False + elif len(self.stderr_buf) > 0: + comp_success, processed_length, data = compress_adaptive(self.stderr_buf) + self.stderr_buf = self.stderr_buf[processed_length:] + send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent + self.stderr_eof_sent = self.stderr_eof_sent or send_eof + msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR, + data, send_eof, comp_success) + self.send(msg) + if send_eof: + self.stderr_eof_sent = True + return True + elif len(self.stdout_buf) > 0: + comp_success, processed_length, data = compress_adaptive(self.stdout_buf) + self.stdout_buf = self.stdout_buf[processed_length:] + send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent + self.stdout_eof_sent = self.stdout_eof_sent or send_eof + msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT, + data, send_eof, comp_success) + self.send(msg) + if send_eof: + self.stdout_eof_sent = True + return True + elif self.return_code is not None and not self.return_code_sent: + msg = protocol.CommandExitedMessage(self.return_code) + self.send(msg) + self.return_code_sent = True + self._call(functools.partial(self._check_protocol_timeout, + lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"), + max(self.outlet.rtt * 5, 10)) + return False + + except Exception as e: RNS.log(f"Error during pump: {e}", RNS.LOG_ERROR) + return False + + def _terminate_process(self): + with contextlib.suppress(Exception): + if self.process and self.process.running: + self.process.terminate() + + def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any], + term: str | None, rows: int, cols: int, hpix: int, vpix: int): + + self.cmdline = self.default_command + if not self.allow_remote_command and cmdline and len(cmdline) > 0: + self.terminate("Remote command line not allowed by listener") + return + + if self.remote_cmd_as_args and cmdline and len(cmdline) > 0: + self.cmdline.extend(cmdline) + elif cmdline and len(cmdline) > 0: + self.cmdline = cmdline + + + self.stdin_is_pipe = pipe_stdin + self.stdout_is_pipe = pipe_stdout + self.stderr_is_pipe = pipe_stderr + self.tcflags = tcflags + self.term = term + + def stdout(data: bytes): + self.stdout_buf.extend(data) + + def stderr(data: bytes): + self.stderr_buf.extend(data) + + try: + self.process = process.CallbackSubprocess(argv=self.cmdline, + env={"TERM": self.term or os.environ.get("TERM") or "xterm", + "RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash) + if self.remote_identity and self.remote_identity.hash else "")}, + loop=self.loop, + stdout_callback=stdout, + stderr_callback=stderr, + terminated_callback=self._terminated, + stdin_is_pipe=self.stdin_is_pipe, + stdout_is_pipe=self.stdout_is_pipe, + stderr_is_pipe=self.stderr_is_pipe) + self.process.start() + self._set_window_size(rows, cols, hpix, vpix) + except Exception as e: + RNS.log(f"Unable to start process for link {self.outlet}: {e}", RNS.LOG_ERROR) + self.terminate("Unable to start process") + + def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int): + self.rows = rows + self.cols = cols + self.hpix = hpix + self.vpix = vpix + with contextlib.suppress(Exception): + self.process.set_winsize(rows, cols, hpix, vpix) + + def _received_stdin(self, data: bytes, eof: bool): + if data and len(data) > 0: + self.process.write(data) + if eof: + self.process.close_stdin() + + def _handle_message(self, message: RNS.MessageBase): + if self.state == LSState.LSSTATE_WAIT_IDENT: + # Ignore any messages until the initiator has identified to avoid race conditions + # between identity announcement and early protocol messages. + RNS.log("Ignoring message while waiting for identification", RNS.LOG_DEBUG) + return + if self.state == LSState.LSSTATE_WAIT_VERS: + if not isinstance(message, protocol.VersionInfoMessage): + self._protocol_error(self.state.name) + return + RNS.log(f"Version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}", RNS.LOG_VERBOSE) + if message.protocol_version != protocol.PROTOCOL_VERSION: + self.terminate("Incompatible protocol") + return + self.send(protocol.VersionInfoMessage()) + self._set_state(LSState.LSSTATE_WAIT_CMD) + return + elif self.state == LSState.LSSTATE_WAIT_CMD: + if not isinstance(message, protocol.ExecuteCommandMesssage): + return self._protocol_error(self.state.name) + RNS.log(f"Execute command message on link {self.outlet}: {message.cmdline}", RNS.LOG_VERBOSE) + self._set_state(LSState.LSSTATE_RUNNING) + self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr, + message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix) + return + elif self.state == LSState.LSSTATE_RUNNING: + if isinstance(message, protocol.WindowSizeMessage): + self._set_window_size(message.rows, message.cols, message.hpix, message.vpix) + elif isinstance(message, protocol.StreamDataMessage): + if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN: + RNS.log(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}", RNS.LOG_ERROR) + return self._protocol_error(self.state.name) + self._received_stdin(message.data, message.eof) + return + elif isinstance(message, protocol.NoopMessage): + # echo noop only on listener--used for keepalive/connectivity check + self.send(message) + return + elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]: + RNS.log(f"Received packet, but in state {self.state.name}", RNS.LOG_ERROR) + return + else: + self._protocol_error("unexpected message") + return + + +class RNSOutlet(LSOutletBase): + + def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]): + def inner_cb(link, identity: _TIdentity): + cb(self, identity) + + self.link.set_remote_identified_callback(inner_cb) + + def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]): + def inner_cb(link): + cb(self) + + self.link.set_link_closed_callback(inner_cb) + + def unset_link_closed_callback(self): + self.link.set_link_closed_callback(None) + + def teardown(self): + self.link.teardown() + + @property + def rtt(self) -> float: + return self.link.rtt + + def __str__(self): + return f"Outlet RNS Link {self.link}" + + def __init__(self, link: RNS.Link): + self.link = link + link.lsoutlet = self + + @staticmethod + def get_outlet(link: RNS.Link): + if hasattr(link, "lsoutlet"): + return link.lsoutlet + + return RNSOutlet(link) \ No newline at end of file