Added rnsh to included utilities

This commit is contained in:
Mark Qvist 2026-04-26 22:24:00 +02:00
parent 3eee369704
commit 6abb31e469
12 changed files with 2431 additions and 0 deletions

View file

@ -0,0 +1,29 @@
# MIT License
#
# Copyright (c) 2023 Aaron Heise
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from ._version import __version__
import os
module_abs_filename = os.path.abspath(__file__)
module_dir = os.path.dirname(module_abs_filename)
def _get_version(): return __version__

View file

@ -0,0 +1 @@
__version__ = "0.2.0"

View file

@ -0,0 +1,26 @@
import contextlib
from contextlib import AbstractContextManager
import logging
import sys
class permit(AbstractContextManager):
"""Context manager to allow specified exceptions
The specified exceptions will be allowed to bubble up. Other
exceptions are suppressed.
After a non-matching exception is suppressed, execution proceeds
with the next statement following the with statement.
with allow(KeyboardInterrupt):
time.sleep(300)
# Execution still resumes here if no KeyboardInterrupt
"""
def __init__(self, *exceptions): self._exceptions = exceptions
def __enter__(self): pass
def __exit__(self, exctype, excinst, exctb):
return exctype is not None and not issubclass(exctype, self._exceptions)

View file

@ -0,0 +1,25 @@
import asyncio
import time
def bitwise_or_if(value: int, condition: bool, orval: int):
if not condition: return value
return value | orval
def check_and(value: int, andval: int) -> bool:
return (value & andval) > 0
class SleepRate:
def __init__(self, target_period: float):
self.target_period = target_period
self.last_wake = time.time()
def next_sleep_time(self) -> float:
old_last_wake = self.last_wake
self.last_wake = time.time()
next_wake = max(old_last_wake + 0.01, self.last_wake)
sleep_for = next_wake - self.last_wake
return sleep_for if sleep_for > 0 else 0
async def sleep_async(self): await asyncio.sleep(self.next_sleep_time())
def sleep_block(self): time.sleep(self.next_sleep_time())

View file

@ -0,0 +1,474 @@
#!/usr/bin/env python3
# MIT License
#
# Copyright (c) 2016-2022 Mark Qvist / unsigned.io
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
import asyncio
import base64
import enum
import functools
import os
import queue
import shlex
import signal
import sys
import termios
import threading
import time
import tty
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
import rnsh.process as process
import rnsh.retry as retry
import rnsh.session as session
import re
import contextlib
import rnsh.args
import pwd
import bz2
import rnsh.protocol as protocol
import rnsh.helpers as helpers
import rnsh.rnsh
_identity = None
_reticulum = None
_cmd: [str] | None = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event = None
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
async def _check_finished(timeout: float = 0):
return _finished is not None and await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, loop):
global _finished
RNS.log(f"{signal.Signals(sig).name}", RNS.LOG_DEBUG)
if _finished is not None: _finished.set()
else: raise KeyboardInterrupt()
async def _spin_tty(until=None, msg=None, timeout=None):
i = 0
syms = "⢄⢂⢁⡁⡈⡐⡠"
if timeout != None: timeout = time.time()+timeout
print(msg+" ", end=" ")
while (timeout == None or time.time()<timeout) and not until():
await asyncio.sleep(0.1)
print(("\b\b"+syms[i]+" "), end="")
sys.stdout.flush()
i = (i+1)%len(syms)
print("\r"+" "*len(msg)+" \r", end="")
if timeout != None and time.time() > timeout: return False
else: return True
async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = None) -> bool:
if timeout is not None: timeout += time.time()
while (timeout is None or time.time() < timeout) and not until():
if await _check_finished(0.1): raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout: return False
else: return True
async def _spin(until: callable = None, msg=None, timeout: float | None = None, quiet: bool = False) -> bool:
if not quiet and os.isatty(1): return await _spin_tty(until, msg, timeout)
else: return await _spin_pipe(until, msg, timeout)
_link: RNS.Link | None = None
_remote_exec_grace = 2.0
_pq = queue.Queue()
class InitiatorState(enum.IntEnum):
IS_INITIAL = 0
IS_LINKED = 1
IS_WAIT_VERS = 2
IS_RUNNING = 3
IS_TERMINATE = 4
IS_TEARDOWN = 5
def _client_link_closed(link):
if _finished: _finished.set()
def _client_message_handler(message: RNS.MessageBase): _pq.put(message)
def compute_target_rns_loglevel(verbosity: int, quietness: int, base_level: int = RNS.LOG_INFO) -> int:
try:
target = int(base_level) + int(verbosity) - int(quietness)
if target < RNS.LOG_CRITICAL: target = RNS.LOG_CRITICAL
if target > RNS.LOG_DEBUG: target = RNS.LOG_DEBUG
return target
except Exception: return base_level
class RemoteExecutionError(Exception):
def __init__(self, msg): self.msg = msg
async def _initiate_link(configdir, rnsconfigdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
global _identity, _reticulum, _link, _destination, _remote_exec_grace
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(destination) != dest_len:
raise RemoteExecutionError(
"Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(
hex=dest_len, byte=dest_len // 2))
try:
destination_hash = bytes.fromhex(destination)
except Exception as e:
raise RemoteExecutionError("Invalid destination entered. Check your input.")
if _reticulum is None:
targetloglevel = compute_target_rns_loglevel(verbosity, quietness, RNS.LOG_ERROR)
RNS.logfile = os.path.join(configdir, "logfile")
_reticulum = RNS.Reticulum(configdir=rnsconfigdir, loglevel=targetloglevel, logdest=RNS.LOG_FILE)
if _identity is None:
_identity = rnsh.rnsh.prepare_identity(identitypath)
if not RNS.Transport.has_path(destination_hash):
RNS.Transport.request_path(destination_hash)
RNS.log(f"Requesting path...", RNS.LOG_INFO)
if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...",
timeout=timeout, quiet=quietness > 0):
raise RemoteExecutionError("Path not found")
if _destination is None:
listener_identity = RNS.Identity.recall(destination_hash)
_destination = RNS.Destination(
listener_identity,
RNS.Destination.OUT,
RNS.Destination.SINGLE,
rnsh.rnsh.APP_NAME
)
if _link is None or _link.status == RNS.Link.PENDING:
RNS.log("No link", RNS.LOG_DEBUG)
_link = RNS.Link(_destination)
_link.did_identify = False
_link.set_link_closed_callback(_client_link_closed)
RNS.log(f"Establishing link...", RNS.LOG_VERBOSE)
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...",
timeout=timeout, quiet=quietness > 0):
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
RNS.log("Have link", RNS.LOG_DEBUG)
if not noid and not _link.did_identify:
# Delay a tiny bit to allow listener to fully enter WAIT_IDENT state
await asyncio.sleep(min(1, _link.rtt * 1.1 + 0.05))
_link.identify(_identity)
_link.did_identify = True
async def _handle_error(errmsg: RNS.MessageBase):
if isinstance(errmsg, protocol.ErrorMessage):
with contextlib.suppress(Exception):
if _link and _link.status == RNS.Link.ACTIVE:
_link.teardown()
await asyncio.sleep(0.1)
raise RemoteExecutionError(f"Remote error: {errmsg.msg}")
async def initiate(configdir: str, rnsconfigdir:str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
timeout: float, command: [str] | None = None):
global _finished, _link
with process.TTYRestorer(sys.stdin.fileno()) as ttyRestorer:
loop = asyncio.get_running_loop()
state = InitiatorState.IS_INITIAL
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
line_buffer = bytearray()
await _initiate_link(configdir=configdir,
rnsconfigdir=rnsconfigdir,
identitypath=identitypath,
verbosity=verbosity,
quietness=quietness,
noid=noid,
destination=destination,
timeout=timeout)
if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]:
return 255
state = InitiatorState.IS_LINKED
outlet = session.RNSOutlet(_link)
channel = _link.get_channel()
protocol.register_message_types(channel)
channel.add_message_handler(_client_message_handler)
# Next step after linking and identifying: send version
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5, quiet=quietness > 0):
# print("Error bringing up link")
# return 253
channel.send(protocol.VersionInfoMessage())
try:
vm = _pq.get(timeout=max(outlet.rtt * 20, 5))
await _handle_error(vm)
if not isinstance(vm, protocol.VersionInfoMessage):
raise Exception("Invalid message received")
RNS.log(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}", RNS.LOG_DEBUG)
state = InitiatorState.IS_RUNNING
except queue.Empty:
print("Protocol error")
return 254
winch = False
def sigwinch_handler():
nonlocal winch
winch = True
esc = False
pre_esc = True
line_mode = False
line_flush = False
blind_write_count = 0
flush_chars = ["\x01", "\x03", "\x04", "\x05", "\x0c", "\x11", "\x13", "\x15", "\x19", "\t", "\x1A", "\x1B"]
def handle_escape(b):
nonlocal line_mode
if b == "?":
os.write(1, "\n\r\n\rSupported rnsh escape sequences:".encode("utf-8"))
os.write(1, "\n\r ~~ Send the escape character by typing it twice".encode("utf-8"))
os.write(1, "\n\r ~. Terminate session and exit immediately".encode("utf-8"))
os.write(1, "\n\r ~L Toggle line-interactive mode".encode("utf-8"))
os.write(1, "\n\r ~? Display this quick reference\n\r".encode("utf-8"))
os.write(1, "\n\r(Escape sequences are only recognized immediately after newline)\n\r".encode("utf-8"))
return None
elif b == ".":
_link.teardown()
return None
elif b == "L":
line_mode = not line_mode
if line_mode:
os.write(1, "\n\rLine-interactive mode enabled\n\r".encode("utf-8"))
else:
os.write(1, "\n\rLine-interactive mode disabled\n\r".encode("utf-8"))
return None
return b
stdin_eof = False
def stdin():
nonlocal stdin_eof, pre_esc, esc, line_mode
nonlocal line_flush, blind_write_count
try:
in_data = process.tty_read(sys.stdin.fileno())
if in_data is not None:
data = bytearray()
for b in bytes(in_data):
c = chr(b)
if c == "\r":
pre_esc = True
line_flush = True
data.append(b)
elif line_mode and c in flush_chars:
pre_esc = False
line_flush = True
data.append(b)
elif line_mode and (c == "\b" or c == "\x7f"):
pre_esc = False
if len(line_buffer)>0:
line_buffer.pop(-1)
blind_write_count -= 1
os.write(1, "\b \b".encode("utf-8"))
elif pre_esc == True and c == "~":
pre_esc = False
esc = True
elif esc == True:
ret = handle_escape(c)
if ret != None:
if ret != "~":
data.append(ord("~"))
data.append(ord(ret))
esc = False
else:
pre_esc = False
data.append(b)
if not line_mode:
data_buffer.extend(data)
else:
line_buffer.extend(data)
if line_flush:
data_buffer.extend(line_buffer)
line_buffer.clear()
os.write(1, ("\b \b"*blind_write_count).encode("utf-8"))
line_flush = False
blind_write_count = 0
else:
os.write(1, data)
blind_write_count += len(data)
except EOFError:
if os.isatty(0):
data_buffer.extend(process.CTRL_D)
stdin_eof = True
process.tty_unset_reader_callbacks(sys.stdin.fileno())
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
tcattr = None
rows, cols, hpix, vpix = (None, None, None, None)
try:
tcattr = termios.tcgetattr(0)
rows, cols, hpix, vpix = process.tty_get_winsize(0)
except:
try:
tcattr = termios.tcgetattr(1)
rows, cols, hpix, vpix = process.tty_get_winsize(1)
except:
try:
tcattr = termios.tcgetattr(2)
rows, cols, hpix, vpix = process.tty_get_winsize(2)
except:
pass
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0)
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
pipe_stdin=not os.isatty(0),
pipe_stdout=not os.isatty(1),
pipe_stderr=not os.isatty(2),
tcflags=tcattr,
term=os.environ.get("TERM", None),
rows=rows,
cols=cols,
hpix=hpix,
vpix=vpix))
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
_finished = asyncio.Event()
loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop))
loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop))
mdu = _link.MDU - 16
sent_eof = False
last_winch = time.time()
sleeper = helpers.SleepRate(0.01)
processed = False
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
try:
try:
message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
await _handle_error(message)
processed = True
if isinstance(message, protocol.StreamDataMessage):
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
if message.data and len(message.data) > 0:
ttyRestorer.raw()
RNS.log(f"stdout: {message.data}", RNS.LOG_DEBUG)
os.write(1, message.data)
sys.stdout.flush()
if message.eof:
os.close(1)
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
if message.data and len(message.data) > 0:
ttyRestorer.raw()
RNS.log(f"stdout: {message.data}", RNS.LOG_DEBUG)
os.write(2, message.data)
sys.stderr.flush()
if message.eof:
os.close(2)
elif isinstance(message, protocol.CommandExitedMessage):
RNS.log(f"received return code {message.return_code}, exiting", RNS.LOG_DEBUG)
return message.return_code
elif isinstance(message, protocol.ErrorMessage):
RNS.log(f"Remote error: {message.data}", RNS.LOG_ERROR)
if message.fatal:
_link.teardown()
return 200
except queue.Empty:
processed = False
if channel.is_ready_to_send():
def compress_adaptive(buf: bytes):
comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
comp_try = 1
comp_success = False
chunk_len = len(buf)
if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN:
chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN
chunk_segment = None
chunk_segment = None
max_data_len = channel.mdu - protocol.StreamDataMessage.OVERHEAD
while chunk_len > 32 and comp_try < comp_tries:
chunk_segment_length = int(chunk_len/comp_try)
compressed_chunk = bz2.compress(buf[:chunk_segment_length])
compressed_length = len(compressed_chunk)
if compressed_length < max_data_len and compressed_length < chunk_segment_length:
comp_success = True
break
else:
comp_try += 1
if comp_success:
diff = max_data_len - len(compressed_chunk)
chunk = compressed_chunk
processed_length = chunk_segment_length
else:
chunk = bytes(buf[:max_data_len])
processed_length = len(chunk)
return comp_success, processed_length, chunk
comp_success, processed_length, chunk = compress_adaptive(data_buffer)
stdin = chunk
data_buffer = data_buffer[processed_length:]
eof = not sent_eof and stdin_eof and len(stdin) == 0
if len(stdin) > 0 or eof:
channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof, comp_success))
sent_eof = eof
processed = True
# send window change, but rate limited
if winch and time.time() - last_winch > _link.rtt * 25:
last_winch = time.time()
winch = False
with contextlib.suppress(Exception):
r, c, h, v = process.tty_get_winsize(0)
channel.send(protocol.WindowSizeMessage(r, c, h, v))
processed = True
except RemoteExecutionError as e:
print(e.msg)
return 255
except Exception as ex:
print(f"Client exception: {ex}")
if _link and _link.status != RNS.Link.CLOSED:
_link.teardown()
return 127
RNS.log("Main loop done", RNS.LOG_DEBUG)
return 0

View file

@ -0,0 +1,219 @@
#!/usr/bin/env python3
# MIT License
#
# Copyright (c) 2016-2022 Mark Qvist / unsigned.io
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
import asyncio
import os
import queue
import shlex
import signal
import sys
import termios
import threading
import time
import tty
from typing import Callable, TypeVar
import RNS
import rnsh.exception as exception
import rnsh.process as process
import rnsh.retry as retry
import rnsh.session as session
import re
import contextlib
import rnsh.args
import pwd
import rnsh.protocol as protocol
import rnsh.helpers as helpers
import rnsh.rnsh
_identity = None
_reticulum = None
_allow_all = False
_allowed_file = None
_allowed_identity_hashes = []
_allowed_file_identity_hashes = []
_cmd: [str] | None = None
DATA_AVAIL_MSG = "data available"
_finished: asyncio.Event = None
_retry_timer: retry.RetryThread | None = None
_destination: RNS.Destination | None = None
_loop: asyncio.AbstractEventLoop | None = None
_no_remote_command = True
_remote_cmd_as_args = False
async def _check_finished(timeout: float = 0):
return await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, loop):
global _finished
RNS.log(f"Signal: {signal.Signals(sig).name}", RNS.LOG_DEBUG)
if _finished is not None: _finished.set()
else: raise KeyboardInterrupt()
def _reload_allowed_file():
global _allowed_file, _allowed_file_identity_hashes
if _allowed_file != None:
try:
with open(_allowed_file, "r") as file:
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
added = 0
line = 0
_allowed_file_identity_hashes = []
for allow in file.read().replace("\r", "").split("\n"):
line += 1
if len(allow) == dest_len:
try:
destination_hash = bytes.fromhex(allow)
_allowed_file_identity_hashes.append(destination_hash)
added += 1
except Exception:
RNS.log(f"Discarded invalid Identity hash in {_allowed_file} at line {line}", RNS.LOG_DEBUG)
ms = "y" if added == 1 else "ies"
RNS.log(f"Loaded {added} allowed identit{ms} from "+str(_allowed_file), RNS.LOG_DEBUG)
except Exception as e: RNS.log(f"Error while reloading allowed indetities file: {e}", RNS.LOG_ERROR)
def compute_target_rns_loglevel(verbosity: int, quietness: int, base_level: int = RNS.LOG_INFO) -> int:
try:
target = int(base_level) + int(verbosity) - int(quietness)
if target < RNS.LOG_CRITICAL: target = RNS.LOG_CRITICAL
if target > RNS.LOG_DEBUG: target = RNS.LOG_DEBUG
return target
except Exception: return base_level
async def listen(configdir, rnsconfigdir, command, identitypath=None, service_name=None, verbosity=0, quietness=0, allowed=None,
allowed_file=None, disable_auth=None, announce_period=900, no_remote_command=True, remote_cmd_as_args=False,
loop: asyncio.AbstractEventLoop = None):
global _identity, _allow_all, _allowed_identity_hashes, _allowed_file, _allowed_file_identity_hashes
global _reticulum, _cmd, _destination, _no_remote_command, _remote_cmd_as_args, _finished
if not loop: loop = asyncio.get_running_loop()
if service_name is None or len(service_name) == 0:
service_name = "default"
RNS.log(f"Using service name {service_name}", RNS.LOG_INFO)
# More -v should increase verbosity (higher RNS.loglevel); -q should decrease it
targetloglevel = compute_target_rns_loglevel(verbosity, quietness, RNS.LOG_INFO)
_reticulum = RNS.Reticulum(configdir=rnsconfigdir, loglevel=targetloglevel)
_identity = rnsh.rnsh.prepare_identity(identitypath, service_name)
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, rnsh.rnsh.APP_NAME)
RNS.log(f"rnsh listening for commands on {RNS.prettyhexrep(_destination.hash)}", RNS.LOG_NOTICE)
_cmd = command
if _cmd is None or len(_cmd) == 0:
shell = None
try: shell = pwd.getpwuid(os.getuid()).pw_shell
except Exception as e: RNS.log(f"Error looking up shell: {e}", RNS.LOG_ERROR)
RNS.log(f"Using {shell} for default command.", RNS.LOG_INFO)
# Ensure a sane shell default. Fall back to /bin/sh if lookup fails.
if not shell or len(shell) == 0: shell = "/bin/sh"
_cmd = [shell]
else: RNS.log(f"Using command {shlex.join(_cmd)}", RNS.LOG_INFO)
_no_remote_command = no_remote_command
session.ListenerSession.allow_remote_command = not no_remote_command
_remote_cmd_as_args = remote_cmd_as_args
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
and (_no_remote_command or _remote_cmd_as_args):
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
session.ListenerSession.default_command = _cmd
session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args
if disable_auth:
_allow_all = True
session.ListenerSession.allow_all = True
else:
if allowed_file is not None:
_allowed_file = allowed_file
_reload_allowed_file()
if allowed is not None:
for a in allowed:
try:
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
if len(a) != dest_len:
raise ValueError(
"Allowed destination length is invalid, must be {hex} hexadecimal " +
"characters ({byte} bytes).".format(
hex=dest_len, byte=dest_len // 2))
try:
destination_hash = bytes.fromhex(a)
_allowed_identity_hashes.append(destination_hash)
session.ListenerSession.allowed_identity_hashes.append(destination_hash)
except Exception:
raise ValueError("Invalid destination entered. Check your input.")
except Exception as e:
RNS.log(f"Unhandled error: {e}", RNS.LOG_ERROR)
RNS.trace_exception(e)
exit(1)
if (len(_allowed_identity_hashes) < 1 and len(_allowed_file_identity_hashes) < 1) and not disable_auth:
RNS.log("Warning: No allowed identities configured, rnsh will not accept any connections!", RNS.LOG_WARNING)
def link_established(lnk: RNS.Link):
_reload_allowed_file()
session.ListenerSession.allowed_file_identity_hashes = _allowed_file_identity_hashes
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop)
_destination.set_link_established_callback(link_established)
_finished = asyncio.Event()
signal.signal(signal.SIGINT, _sigint_handler)
if announce_period is not None: _destination.announce()
last_announce = time.time()
sleeper = helpers.SleepRate(0.01)
try:
while not await _check_finished():
if announce_period and 0 < announce_period < time.time() - last_announce:
last_announce = time.time()
_destination.announce()
if len(session.ListenerSession.sessions) > 0:
# no sleep if there's work to do
if not await session.ListenerSession.pump_all():
await sleeper.sleep_async()
else:
await asyncio.sleep(0.25)
finally:
RNS.log("Shutting down", RNS.LOG_NOTICE)
await session.ListenerSession.terminate_all("Shutting down")
await asyncio.sleep(1)
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
for link in links_still_active:
if link.status not in [RNS.Link.CLOSED]:
link.teardown()
await asyncio.sleep(0.01)

View file

@ -0,0 +1,12 @@
import asyncio
import functools
from typing import Callable
def sig_handler_sys_to_loop(handler: Callable[[int, any], None]) -> Callable[[int, asyncio.AbstractEventLoop], None]:
def wrapped(cb: Callable[[int, any], None], signal: int, loop: asyncio.AbstractEventLoop): cb(signal, None)
return functools.partial(wrapped, handler)
def loop_set_signal(sig, handler: Callable[[int, asyncio.AbstractEventLoop], None], loop: asyncio.AbstractEventLoop = None):
if loop is None: loop = asyncio.get_running_loop()
loop.remove_signal_handler(sig)
loop.add_signal_handler(sig, functools.partial(handler, sig, loop))

View file

@ -0,0 +1,773 @@
# MIT License
#
# Copyright (c) 2023 Aaron Heise
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
import asyncio
import contextlib
import copy
import errno
import fcntl
import functools
import os
import pty
import select
import signal
import struct
import sys
import termios
import threading
import tty
import types
import typing
import RNS
import rnsh.exception as exception
CTRL_C = "\x03".encode("utf-8")
CTRL_D = "\x04".encode("utf-8")
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop = None):
"""
Add an async reader callback for a tty file descriptor.
Example usage:
def reader():
data = tty_read(fd)
# do something with data
tty_add_reader_callback(self._child_fd, reader, self._loop)
:param fd: file descriptor
:param callback: callback function
:param loop: asyncio event loop to which the reader should be added. If None, use the currently-running loop.
"""
if loop is None:
loop = asyncio.get_running_loop()
loop.add_reader(fd, callback)
def tty_read(fd: int) -> bytes:
"""
Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using
tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys.
:param fd: tty file descriptor
:return: bytes read
"""
if fd_is_closed(fd):
raise EOFError
try:
run = True
result = bytearray()
while not fd_is_closed(fd):
ready, _, _ = select.select([fd], [], [], 0)
if len(ready) == 0:
break
for f in ready:
try:
data = os.read(f, 4096)
except OSError as e:
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
raise
else:
if not data: # EOF
if data is not None and len(data) > 0:
result.extend(data)
return result
elif len(result) > 0:
return result
else:
raise EOFError
if data is not None and len(data) > 0:
result.extend(data)
return result
except EOFError: raise
except Exception as e: RNS.log(f"TTY read error: {e}", RNS.LOG_ERROR)
def tty_read_poll(fd: int) -> bytes:
"""
Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using
tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys.
:param fd: tty file descriptor
:return: bytes read
"""
if fd_is_closed(fd):
raise EOFError
result = bytearray()
try:
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
while True:
try:
data = os.read(fd, 4096)
if not data:
# EOF
if len(result) > 0:
return result
raise EOFError
result.extend(data)
# continue loop to drain
except OSError as e:
if e.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
break
if e.errno == errno.EIO:
if len(result) > 0:
return result
raise EOFError
raise
except EOFError: raise
except Exception as e: RNS.log(f"TTY read error: {e}", RNS.LOG_ERROR)
return result
def fd_is_closed(fd: int) -> bool:
"""
Check if file descriptor is closed
:param fd: file descriptor
:return: True if file descriptor is closed
"""
try:
fcntl.fcntl(fd, fcntl.F_GETFL) < 0
except OSError as ose:
return ose.errno == errno.EBADF
def tty_unset_reader_callbacks(fd: int, loop: asyncio.AbstractEventLoop = None):
"""
Remove async reader callbacks for file descriptor.
:param fd: file descriptor
:param loop: asyncio event loop from which to remove callbacks
"""
with exception.permit(SystemExit):
if loop is None:
loop = asyncio.get_running_loop()
loop.remove_reader(fd)
def tty_get_winsize(fd: int) -> [int, int, int, int]:
"""
Ge the window size of a tty.
:param fd: file descriptor of tty
:return: (rows, cols, h_pixels, v_pixels)
"""
packed = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0))
rows, cols, h_pixels, v_pixels = struct.unpack('HHHH', packed)
return rows, cols, h_pixels, v_pixels
def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int):
"""
Set the window size on a tty.
:param fd: file descriptor of tty
:param rows: number of visible rows
:param cols: number of visible columns
:param h_pixels: number of visible horizontal pixels
:param v_pixels: number of visible vertical pixels
"""
if fd < 0:
return
packed = struct.pack('HHHH', rows, cols, h_pixels, v_pixels)
fcntl.ioctl(fd, termios.TIOCSWINSZ, packed)
def process_exists(pid) -> bool:
"""
Check For the existence of a unix pid.
:param pid: process id to check
:return: True if process exists
"""
try:
os.kill(pid, 0)
except OSError:
return False
else:
return True
class TTYRestorer(contextlib.AbstractContextManager):
# Indexes of flags within the attrs array
ATTR_IDX_IFLAG = 0
ATTR_IDX_OFLAG = 1
ATTR_IDX_CFLAG = 2
ATTR_IDX_LFLAG = 4
ATTR_IDX_CC = 5
def __init__(self, fd: int, suppress_logs=False):
"""
Saves termios attributes for a tty for later restoration.
The attributes are an array of values with the following meanings.
tcflag_t c_iflag; /* input modes */
tcflag_t c_oflag; /* output modes */
tcflag_t c_cflag; /* control modes */
tcflag_t c_lflag; /* local modes */
cc_t c_cc[NCCS]; /* special characters */
:param fd: file descriptor of tty
"""
self._fd = fd
self._tattr = None
self._suppress_logs = suppress_logs
self._tattr = self.current_attr()
if not self._tattr and not self._suppress_logs: RNS.log(f"Could not get attrs for fd {fd}", RNS.LOG_DEBUG)
def raw(self):
"""
Set raw mode on tty
"""
if self._fd is None:
return
with contextlib.suppress(termios.error):
tty.setraw(self._fd, termios.TCSANOW)
def original_attr(self) -> [any]:
return copy.deepcopy(self._tattr)
def current_attr(self) -> [any]:
"""
Get the current termios attributes for the wrapped fd.
:return: attribute array
"""
if self._fd is None:
return None
with contextlib.suppress(termios.error):
return copy.deepcopy(termios.tcgetattr(self._fd))
return None
def set_attr(self, attr: [any], when: int = termios.TCSADRAIN):
"""
Set termios attributes
:param attr: attribute list to set
:param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
"""
if not attr or self._fd is None:
return
with contextlib.suppress(termios.error):
termios.tcsetattr(self._fd, when, attr)
def isatty(self):
return os.isatty(self._fd) if self._fd is not None else None
def restore(self):
"""
Restore termios settings to state captured in constructor.
"""
self.set_attr(self._tattr, termios.TCSADRAIN)
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.restore()
return False #__exc_type is not None and issubclass(__exc_type, termios.error)
def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None):
if not loop:
loop = asyncio.get_running_loop()
#TODO: this is hacky
async def wait():
while not evt.is_set():
await asyncio.sleep(0.1)
return True
return loop.create_task(wait())
class AggregateException(Exception):
def __init__(self, inner_exceptions: [Exception]):
super().__init__()
self.inner_exceptions = inner_exceptions
def __str__(self):
return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(lambda e: str(e), self.inner_exceptions))
async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> (any, any):
tasks = list(map(lambda evt: (evt, _task_from_event(evt)), evts))
try:
finished, unfinished = await asyncio.wait(map(lambda t: t[1], tasks),
timeout=timeout,
return_when=asyncio.FIRST_COMPLETED)
if len(unfinished) > 0:
for task in unfinished:
task.cancel()
await asyncio.wait(unfinished)
exceptions = []
for f in finished:
ex = f.exception()
if ex and not isinstance(ex, asyncio.CancelledError) and not isinstance(ex, TimeoutError):
exceptions.append(ex)
if len(exceptions) > 0:
raise AggregateException(exceptions)
return next(map(lambda t: next(map(lambda tt: tt[0], tasks)), finished), None)
finally:
unfinished = []
for task in map(lambda t: t[1], tasks):
if task.done():
if not task.cancelled():
task.exception()
else:
task.cancel()
unfinished.append(task)
if len(unfinished) > 0:
await asyncio.wait(unfinished)
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
"""
Wait for event to be set, or timeout to expire.
:param evt: asyncio.Event to wait on
:param timeout: maximum number of seconds to wait.
:return: True if event was set, False if timeout expired
"""
await event_wait_any([evt], timeout=timeout)
return evt.is_set()
def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool, stdout_is_pipe: bool,
stderr_is_pipe: bool) -> tuple[int, int, int, int]:
# Set up PTY and/or pipes
child_fd = parent_fd = None
if not (stdin_is_pipe and stdout_is_pipe and stderr_is_pipe):
parent_fd, child_fd = pty.openpty()
child_stdin, parent_stdin = (os.pipe() if stdin_is_pipe else (child_fd, parent_fd))
parent_stdout, child_stdout = (os.pipe() if stdout_is_pipe else (parent_fd, child_fd))
parent_stderr, child_stderr = (os.pipe() if stderr_is_pipe else (parent_fd, child_fd))
# Fork
pid = os.fork()
if pid == 0:
try:
# We are in the child process, so close all open sockets and pipes except for the PTY and/or pipes
max_fd = os.sysconf("SC_OPEN_MAX")
for fd in range(3, max_fd):
if fd not in (child_stdin, child_stdout, child_stderr):
try:
os.close(fd)
except OSError:
pass
# Set up PTY and/or pipes
os.dup2(child_stdin, 0)
os.dup2(child_stdout, 1)
os.dup2(child_stderr, 2)
# Make PTY controlling if necessary so that CTRL_C/CTRL_D behave as expected
if child_fd is not None:
os.setsid()
try:
tty_fd = 0 if not stdin_is_pipe else (1 if not stdout_is_pipe else 2)
# Set controlling TTY for this session
fcntl.ioctl(tty_fd, termios.TIOCSCTTY, 0)
except Exception:
pass
# Ensure the child is the foreground process group for the TTY
try:
os.setpgid(0, 0)
pgid = os.getpgrp()
import struct as _struct
fcntl.ioctl(tty_fd, termios.TIOCSPGRP, _struct.pack('i', pgid))
except Exception:
pass
# Ensure canonical input with signals and local echo enabled
try:
tty_fd = 0 if not stdin_is_pipe else (1 if not stdout_is_pipe else 2)
attrs = termios.tcgetattr(tty_fd)
lflag = attrs[3]
lflag |= termios.ICANON | termios.ISIG | termios.ECHO
attrs[3] = lflag
termios.tcsetattr(tty_fd, termios.TCSANOW, attrs)
except Exception:
pass
# Execute the command
os.execvpe(cmd_line[0], cmd_line, env)
except Exception as err:
exc_type, exc_obj, exc_tb = sys.exc_info()
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
print(f"Unable to start {cmd_line[0]}: {err} ({fname}:{exc_tb.tb_lineno})")
sys.stdout.flush()
# don't let any other modules get in our way, do an immediate silent exit.
os._exit(255)
else:
# We are in the parent process, so close the child-side of the PTY and/or pipes
if child_fd is not None:
os.close(child_fd)
if child_stdin != child_fd:
os.close(child_stdin)
if child_stdout != child_fd:
os.close(child_stdout)
if child_stderr != child_fd:
os.close(child_stderr)
# # Close the write end of the pipe if a pipe is used for standard input
# if not stdin_is_pipe:
# os.close(parent_stdin)
# Return the child PID and the file descriptors for the PTY and/or pipes
return pid, parent_stdin, parent_stdout, parent_stderr
class CallbackSubprocess:
# time between checks of child process
PROCESS_POLL_TIME: float = 0.1
# Close pipes soon after process exit to avoid scheduling on closed event loops
PROCESS_PIPE_TIME: int = 1
def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool,
stderr_is_pipe: bool):
"""
Fork a child process and generate callbacks with output from the process.
:param argv: the command line, tokenized. The first element must be the absolute path to an executable file.
:param env: environment variables to override
:param loop: the asyncio event loop to use
:param stdout_callback: callback for data, e.g. def callback(data:bytes) -> None
:param terminated_callback: callback for termination/return code, e.g. def callback(return_code:int) -> None
"""
assert loop is not None, "loop should not be None"
assert stdout_callback is not None, "stdout_callback should not be None"
assert terminated_callback is not None, "terminated_callback should not be None"
self._command: [str] = argv
self._env = env or {}
self._loop = loop
self._stdout_cb = stdout_callback
self._stderr_cb = stderr_callback
self._terminated_cb = terminated_callback
self._pid: int = None
self._child_stdin: int = None
self._child_stdout: int = None
self._child_stderr: int = None
self._return_code: int = None
self._stdout_eof: bool = False
self._stderr_eof: bool = False
self._stdin_is_pipe = stdin_is_pipe
self._stdout_is_pipe = stdout_is_pipe
self._stderr_is_pipe = stderr_is_pipe
self._at_line_start: bool = True
self._tty_line_buffer: bytearray = bytearray()
def _ensure_pipes_closed(self):
stdin = self._child_stdin
stdout = self._child_stdout
stderr = self._child_stderr
fds = set(filter(lambda x: x is not None, list({stdin, stdout, stderr})))
RNS.log(f"Queuing close of pipes for ended process (fds: {fds})", RNS.LOG_DEBUG)
def ensure_pipes_closed_inner():
RNS.log(f"Ensuring pipes are closed (fds: {fds})", RNS.LOG_DEBUG)
for fd in fds:
RNS.log(f"Closing fd {fd}", RNS.LOG_DEBUG)
with contextlib.suppress(OSError): tty_unset_reader_callbacks(fd)
with contextlib.suppress(OSError): os.close(fd)
self._child_stdin = None
self._child_stdout = None
self._child_stderr = None
# Avoid scheduling on a closed loop
if self._loop.is_closed(): ensure_pipes_closed_inner()
else: self._loop.call_later(CallbackSubprocess.PROCESS_PIPE_TIME, ensure_pipes_closed_inner)
def terminate(self, kill_delay: float = 1.0):
"""
Terminate child process if running
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
"""
RNS.log("terminate()", RNS.LOG_EXTREME)
if not self.running: return
with exception.permit(SystemExit): os.kill(self._pid, signal.SIGTERM)
def kill():
if process_exists(self._pid):
RNS.log("kill()", RNS.LOG_EXTREME)
with exception.permit(SystemExit):
os.kill(self._pid, signal.SIGHUP)
os.kill(self._pid, signal.SIGKILL)
self._loop.call_later(kill_delay, kill)
def wait():
RNS.log("wait()", RNS.LOG_EXTREME)
with contextlib.suppress(OSError): os.waitpid(self._pid, 0)
self._ensure_pipes_closed()
RNS.log("wait() finish", RNS.LOG_EXTREME)
threading.Thread(target=wait, daemon=True).start()
def close_stdin(self):
with contextlib.suppress(Exception):
os.close(self._child_stdin)
# Encourage prompt shutdown if child lingers after stdin close
def _ensure_terminate():
if self.running:
self.terminate(kill_delay=0.2)
if not self._loop.is_closed():
self._loop.call_later(0.05, _ensure_terminate)
@property
def started(self) -> bool:
"""
:return: True if child process has been started
"""
return self._pid is not None
@property
def running(self) -> bool:
"""
:return: True if child process is still running
"""
return self._pid is not None and process_exists(self._pid)
def write(self, data: bytes):
"""
Write bytes to the stdin of the child process.
:param data: bytes to write
"""
os.write(self._child_stdin, data)
# TODO: Check what this is actually supposed to solve.
#
# For pipe-in + TTY-out, echo should be visible immediately
if self._stdin_is_pipe and not self._stdout_is_pipe and self._stdout_cb is not None and data not in (CTRL_C, CTRL_D):
try: self._stdout_cb(data)
except Exception: pass
def set_winsize(self, r: int, c: int, h: int, v: int):
"""
Set the window size on the tty of the child process.
:param r: rows visible
:param c: columns visible
:param h: horizontal pixels visible
:param v: vertical pixels visible
:return:
"""
RNS.log(f"set_winsize({r},{c},{h},{v}", RNS.LOG_DEBUG)
tty_set_winsize(self._child_stdout, r, c, h, v)
def copy_winsize(self, fromfd: int):
"""
Copy window size from one tty to another.
:param fromfd: source tty file descriptor
"""
r, c, h, v = tty_get_winsize(fromfd)
self.set_winsize(r, c, h, v)
def tcsetattr(self, when: int, attr: list[any]): # actual type is list[int | list[int | bytes]]
"""
Set tty attributes.
:param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH
:param attr: attributes to set
"""
termios.tcsetattr(self._child_stdin, when, attr)
def tcgetattr(self) -> list[any]: # actual type is list[int | list[int | bytes]]
"""
Get tty attributes.
:return: tty attributes value
"""
return termios.tcgetattr(self._child_stdout)
def ttysetraw(self):
tty.setraw(self._child_stdout, termios.TCSADRAIN)
def start(self):
"""
Start the child process.
"""
RNS.log("start()", RNS.LOG_EXTREME)
# # Using the parent environment seems to do some weird stuff, at least on macOS
# parentenv = os.environ.copy()
# env = {"HOME": parentenv["HOME"],
# "PATH": parentenv["PATH"],
# "TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
# "LANG": parentenv.get("LANG"),
# "SHELL": self._command[0]}
env = os.environ.copy()
for key in self._env:
env[key] = self._env[key]
program = self._command[0]
assert isinstance(program, str)
# match = re.search("^/bin/(.*sh)$", program)
# if match:
# self._command[0] = "-" + match.group(1)
# env["SHELL"] = program
# self._log.debug(f"set login shell {self._command}")
self._pid, \
self._child_stdin, \
self._child_stdout, \
self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe,
self._stderr_is_pipe)
RNS.log(f"Started pid {self.pid}, fds: {self._child_stdin}, {self._child_stdout}, {self._child_stderr}", RNS.LOG_DEBUG)
def poll():
try:
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
if self._return_code is not None:
self._return_code = self._return_code & 0xff
if self._return_code is not None and not process_exists(self._pid):
RNS.log(f"polled return code {self._return_code}", RNS.LOG_DEBUG)
self._terminated_cb(self._return_code)
if self.running:
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
else:
self._ensure_pipes_closed()
except Exception as e:
if not hasattr(e, "errno") or e.errno != errno.ECHILD:
RNS.log(f"Error in process poll: {e}", RNS.LOG_DEBUG)
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
def stdout():
try:
with exception.permit(SystemExit):
data = tty_read_poll(self._child_stdout)
if data is not None and len(data) > 0:
self._stdout_cb(data)
# Opportunistically drain shortly after to coalesce immediate follow-up output
if not self._loop.is_closed():
self._loop.call_later(0.01, stdout)
except EOFError:
self._stdout_eof = True
tty_unset_reader_callbacks(self._child_stdout)
self._stdout_cb(bytearray())
def stderr():
try:
with exception.permit(SystemExit):
data = tty_read_poll(self._child_stderr)
if data is not None and len(data) > 0:
self._stderr_cb(data)
if not self._loop.is_closed():
self._loop.call_later(0.01, stderr)
except EOFError:
self._stderr_eof = True
tty_unset_reader_callbacks(self._child_stderr)
self._stderr_cb(bytearray())
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
if self._child_stderr != self._child_stdout:
tty_add_reader_callback(self._child_stderr, stderr, self._loop)
@property
def stdout_eof(self):
return self._stdout_eof or not self.running
@property
def stderr_eof(self):
return self._stderr_eof or not self.running
@property
def return_code(self) -> int:
return self._return_code
@property
def pid(self) -> int:
return self._pid
async def main():
"""
A test driver for the CallbackProcess class.
python ./process.py /bin/zsh --login
"""
if len(sys.argv) <= 1:
print(f"Usage: {sys.argv} <absolute_path_to_child_executable> [child_arg ...]")
exit(1)
loop = asyncio.get_event_loop()
# asyncio.set_event_loop(loop)
retcode = loop.create_future()
def stdout(data: bytes): os.write(sys.stdout.fileno(), data)
def terminated(rc: int): retcode.set_result(rc)
process = CallbackSubprocess(argv=sys.argv[1:],
env={"TERM": os.environ.get("TERM", "xterm")},
loop=loop,
stdout_callback=stdout,
terminated_callback=terminated)
def sigint_handler(sig, frame):
if process is None or process.started and not process.running:
raise KeyboardInterrupt
elif process.running:
process.write("\x03".encode("utf-8"))
def sigwinch_handler(sig, frame):
process.copy_winsize(sys.stdin.fileno())
signal.signal(signal.SIGINT, sigint_handler)
signal.signal(signal.SIGWINCH, sigwinch_handler)
def stdin():
try:
data = tty_read(sys.stdin.fileno())
if data is not None:
process.write(data)
except EOFError:
tty_unset_reader_callbacks(sys.stdin.fileno())
process.write(CTRL_D)
tty_add_reader_callback(sys.stdin.fileno(), stdin)
process.start()
# call_soon called it too soon, not sure why.
loop.call_later(0.001, functools.partial(process.copy_winsize, sys.stdin.fileno()))
val = await retcode
RNS.log(f"Got return code {val}", RNS.LOG_DEBUG)
return val
if __name__ == "__main__":
tr = TTYRestorer(sys.stdin.fileno())
try:
tr.raw()
asyncio.run(main())
finally:
tty_unset_reader_callbacks(sys.stdin.fileno())
tr.restore()

View file

@ -0,0 +1,115 @@
from __future__ import annotations
import RNS
from RNS.vendor import umsgpack
from RNS.Buffer import StreamDataMessage as RNSStreamDataMessage
import rnsh.retry
import abc
import contextlib
import struct
from abc import ABC, abstractmethod
MSG_MAGIC = 0xac
PROTOCOL_VERSION = 1
def _make_MSGTYPE(val: int):
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
class NoopMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(0)
def pack(self) -> bytes: return bytes()
def unpack(self, raw): pass
class WindowSizeMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(2)
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
super().__init__()
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
def pack(self) -> bytes: return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
def unpack(self, raw): self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
class ExecuteCommandMesssage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(3)
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
cols: int = None, hpix: int = None, vpix: int = None):
super().__init__()
self.cmdline = cmdline
self.pipe_stdin = pipe_stdin
self.pipe_stdout = pipe_stdout
self.pipe_stderr = pipe_stderr
self.tcflags = tcflags
self.term = term
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
def pack(self) -> bytes:
return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
def unpack(self, raw):
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
# Create a version of RNS.Buffer.StreamDataMessage that we control
class StreamDataMessage(RNSStreamDataMessage):
MSGTYPE = _make_MSGTYPE(4)
STREAM_ID_STDIN = 0
STREAM_ID_STDOUT = 1
STREAM_ID_STDERR = 2
class VersionInfoMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(5)
def __init__(self, sw_version: str = None):
super().__init__()
self.sw_version = sw_version or rnsh.__version__
self.protocol_version = PROTOCOL_VERSION
def pack(self) -> bytes: return umsgpack.packb((self.sw_version, self.protocol_version))
def unpack(self, raw): self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
class ErrorMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(6)
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
super().__init__()
self.msg = msg
self.fatal = fatal
self.data = data
def pack(self) -> bytes: return umsgpack.packb((self.msg, self.fatal, self.data))
def unpack(self, raw: bytes): self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
class CommandExitedMessage(RNS.MessageBase):
MSGTYPE = _make_MSGTYPE(7)
def __init__(self, return_code: int = None):
super().__init__()
self.return_code = return_code
def pack(self) -> bytes: return umsgpack.packb(self.return_code)
def unpack(self, raw: bytes): self.return_code = umsgpack.unpackb(raw)
message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage,
CommandExitedMessage, ErrorMessage]
def register_message_types(channel: RNS.Channel.Channel):
for message_type in message_types: channel.register_message_type(message_type)

189
RNS/Utilities/rnsh/retry.py Normal file
View file

@ -0,0 +1,189 @@
# MIT License
#
# Copyright (c) 2023 Aaron Heise
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
import asyncio
import threading
import time
import rnsh.exception as exception
from typing import Callable
from contextlib import AbstractContextManager
import types
import typing
class RetryStatus:
def __init__(self, tag: any, try_limit: int, wait_delay: float, retry_callback: Callable[[any, int], any],
timeout_callback: Callable[[any, int], None], tries: int = 1):
self.tag = tag
self.try_limit = try_limit
self.tries = tries
self.wait_delay = wait_delay
self.retry_callback = retry_callback
self.timeout_callback = timeout_callback
self.try_time = time.time()
self.completed = False
@property
def ready(self):
ready = time.time() > self.try_time + self.wait_delay
RNS.log(f"ready check {self.tag} try_time {self.try_time} wait_delay {self.wait_delay} " +
f"next_try {self.try_time + self.wait_delay} now {time.time()} " +
f"exceeded {time.time() - self.try_time - self.wait_delay} ready {ready}", RNS.LOG_DEBUG)
return ready
@property
def timed_out(self):
return self.ready and self.tries >= self.try_limit
def timeout(self):
self.completed = True
self.timeout_callback(self.tag, self.tries)
def retry(self) -> any:
self.tries = self.tries + 1
self.try_time = time.time()
return self.retry_callback(self.tag, self.tries)
class RetryThread(AbstractContextManager):
def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
self._loop_period = loop_period
self._statuses: list[RetryStatus] = []
self._tag_counter = 0
self._lock = threading.RLock()
self._run = True
self._finished: asyncio.Future = None
self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
self._thread.start()
def is_alive(self):
return self._thread.is_alive()
def close(self, loop: asyncio.AbstractEventLoop = None) -> asyncio.Future:
RNS.log("Stopping timer thread", RNS.LOG_DEBUG)
if loop is None:
self._run = False
self._thread.join()
return None
else:
self._finished = loop.create_future()
return self._finished
def wait(self, timeout: float = None):
if timeout:
timeout = timeout + time.time()
while timeout is None or time.time() < timeout:
with self._lock:
task_count = len(self._statuses)
if task_count == 0:
return
time.sleep(0.1)
def _thread_run(self):
while self._run and self._finished is None:
time.sleep(self._loop_period)
ready: list[RetryStatus] = []
prune: list[RetryStatus] = []
with self._lock: ready.extend(list(filter(lambda s: s.ready, self._statuses)))
for retry in ready:
try:
if not retry.completed:
if retry.timed_out:
RNS.log(f"Timed out {retry.tag} after {retry.try_limit} tries", RNS.LOG_DEBUG)
retry.timeout()
prune.append(retry)
elif retry.ready:
RNS.log(f"Retrying {retry.tag}, try {retry.tries + 1}/{retry.try_limit}", RNS.LOG_DEBUG)
should_continue = retry.retry()
if not should_continue: self.complete(retry.tag)
except Exception as e:
RNS.log(f"Error processing retry id {retry.tag}: {e}", RNS.LOG_ERROR)
prune.append(retry)
with self._lock:
for retry in prune:
RNS.log(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}", RNS.LOG_DEBUG)
with exception.permit(SystemExit): self._statuses.remove(retry)
if self._finished is not None: self._finished.set_result(None)
def _get_next_tag(self):
self._tag_counter += 1
return self._tag_counter
def has_tag(self, tag: any) -> bool:
with self._lock: return next(filter(lambda s: s.tag == tag, self._statuses), None) is not None
def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any],
timeout_callback: Callable[[any, int], None]) -> any:
RNS.log(f"Running first try", RNS.LOG_DEBUG)
tag = try_callback(None, 1)
RNS.log(f"First try got id {tag}", RNS.LOG_DEBUG)
if not tag:
RNS.log(f"Callback returned None/False/0, considering complete.", RNS.LOG_DEBUG)
return None
with self._lock:
if tag is None: tag = self._get_next_tag()
self.complete(tag)
self._statuses.append(RetryStatus(tag=tag,
tries=1,
try_limit=try_limit,
wait_delay=wait_delay,
retry_callback=try_callback,
timeout_callback=timeout_callback))
RNS.log(f"Added retry timer for {tag}", RNS.LOG_DEBUG)
return tag
def complete(self, tag: any):
assert tag is not None
with self._lock:
status = next(filter(lambda l: l.tag == tag, self._statuses), None)
if status is not None:
status.completed = True
self._statuses.remove(status)
RNS.log(f"completed {tag}", RNS.LOG_DEBUG)
return
RNS.log(f"status not found to complete {tag}", RNS.LOG_DEBUG)
def complete_all(self):
with self._lock:
for status in self._statuses:
status.completed = True
RNS.log(f"completed {status.tag}", RNS.LOG_DEBUG)
self._statuses.clear()
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
__traceback: types.TracebackType) -> bool:
self.close()
return False

161
RNS/Utilities/rnsh/rnsh.py Normal file
View file

@ -0,0 +1,161 @@
#!/usr/bin/env python3
# MIT License
#
# Copyright (c) 2016-2022 Mark Qvist / unsigned.io
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
import asyncio
import base64
import re
import os
import sys
import RNS
import rnsh.process as process
import rnsh.session as session
import rnsh.args
import rnsh.loop
import rnsh.listener as listener
import rnsh.initiator as initiator
APP_NAME = "rnsh"
loop: asyncio.AbstractEventLoop | None = None
def _sanitize_service_name(service_name:str) -> str: return re.sub(r'\W+', '', service_name)
def prepare_identity(identity_path, service_name: str = None) -> tuple[RNS.Identity]:
service_name = _sanitize_service_name(service_name or "")
if identity_path is None:
identity_path = RNS.Reticulum.identitypath + "/" + APP_NAME + \
(f".{service_name}" if service_name and len(service_name) > 0 else "")
identity = None
if os.path.isfile(identity_path):
identity = RNS.Identity.from_file(identity_path)
if identity is None:
RNS.log("No valid saved identity found, creating new...", RNS.LOG_INFO)
identity = RNS.Identity()
identity.to_file(identity_path)
return identity
def print_identity(configdir, identitypath, service_name, include_destination: bool):
reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO)
if service_name and len(service_name) > 0:
print(f"Using service name \"{service_name}\"")
identity = prepare_identity(identitypath, service_name)
destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME)
print("Identity : " + str(identity))
if include_destination:
print("Listening on : " + RNS.prettyhexrep(destination.hash))
exit(0)
verbose_set = False
def ensure_config_directory():
if os.path.isdir(os.path.expanduser("~/.config/rnsh")): return os.path.expanduser("~/.config/rnsh")
elif os.path.isdir(os.path.expanduser("~/.rnsh")): return os.path.expanduser("~/.rnsh")
else:
try:
os.makedirs(os.path.expanduser("~/.rnsh"))
return os.path.expanduser("~/.rnsh")
except Exception as e:
RNS.log(f"Could not get or create rnsh configuration directory, aborting", RNS.LOG_CRITICAL)
os._exit(1)
async def _rnsh_cli_main():
global verbose_set
args = rnsh.args.Args(sys.argv)
verbose_set = args.verbose > 0
configdir = ensure_config_directory()
if args.print_identity:
print_identity(args.config, args.identity, args.service_name, args.listen)
return 0
if args.listen:
allowed_file = None
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
if os.path.isfile(os.path.expanduser("~/.config/rnsh/allowed_identities")):
allowed_file = os.path.expanduser("~/.config/rnsh/allowed_identities")
elif os.path.isfile(os.path.expanduser("~/.rnsh/allowed_identities")):
allowed_file = os.path.expanduser("~/.rnsh/allowed_identities")
await listener.listen(configdir=configdir,
rnsconfigdir=args.config,
command=args.command_line,
identitypath=args.identity,
service_name=args.service_name,
verbosity=args.verbose,
quietness=args.quiet,
allowed=args.allowed,
allowed_file=allowed_file,
disable_auth=args.no_auth,
announce_period=args.announce,
no_remote_command=args.no_remote_cmd,
remote_cmd_as_args=args.remote_cmd_as_args)
return 0
if args.destination is not None:
return_code = await initiator.initiate(configdir=configdir,
rnsconfigdir=args.config,
identitypath=args.identity,
verbosity=args.verbose,
quietness=args.quiet,
noid=args.no_id,
destination=args.destination,
timeout=args.timeout,
command=args.command_line
)
return return_code if args.mirror else 0
else:
print("")
print(rnsh.args.usage)
print("")
return 1
def main():
global verbose_set
return_code = 1
exc = None
try: return_code = asyncio.run(_rnsh_cli_main())
except SystemExit: pass
except KeyboardInterrupt: pass
except Exception as ex:
print(f"Unhandled exception: {ex}")
exc = ex
process.tty_unset_reader_callbacks(0)
if verbose_set and exc: raise exc
sys.exit(return_code if return_code is not None else 255)
if __name__ == "__main__": main()

View file

@ -0,0 +1,407 @@
from __future__ import annotations
import contextlib
import functools
import rnsh.exception as exception
import asyncio
import rnsh.process as process
import rnsh.helpers as helpers
import rnsh.protocol as protocol
import enum
from typing import TypeVar, Generic, Callable, List
from abc import abstractmethod, ABC
from multiprocessing import Manager
import os
import bz2
import RNS
_TLink = TypeVar("_TLink")
_TIdentity = TypeVar("_TIdentity")
class SEType(enum.IntEnum):
SE_LINK_CLOSED = 0
class SessionException(Exception):
def __init__(self, setype: SEType, msg: str, *args):
super().__init__(msg, args)
self.type = setype
class LSState(enum.IntEnum):
LSSTATE_WAIT_IDENT = 1
LSSTATE_WAIT_VERS = 2
LSSTATE_WAIT_CMD = 3
LSSTATE_RUNNING = 4
LSSTATE_ERROR = 5
LSSTATE_TEARDOWN = 6
class LSOutletBase(ABC):
@abstractmethod
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]): raise NotImplemented()
@abstractmethod
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]): raise NotImplemented()
@abstractmethod
def unset_link_closed_callback(self): raise NotImplemented()
@property
@abstractmethod
def rtt(self): raise NotImplemented()
@abstractmethod
def teardown(self): raise NotImplemented()
class ListenerSession:
sessions: List[ListenerSession] = []
allowed_identity_hashes: [any] = []
allowed_file_identity_hashes: [any] = []
allow_all: bool = False
allow_remote_command: bool = False
default_command: [str] = []
remote_cmd_as_args = False
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
RNS.log(f"Session started for {outlet}", RNS.LOG_INFO)
self.outlet = outlet
self.channel = channel
self.outlet.set_initiator_identified_callback(self._initiator_identified)
self.outlet.set_link_closed_callback(self._link_closed)
self.loop = loop
self.state: LSState = None
self.remote_identity = None
self.term: str | None = None
self.stdin_is_pipe: bool = False
self.stdout_is_pipe: bool = False
self.stderr_is_pipe: bool = False
self.tcflags: [any] = None
self.cmdline: [str] = None
self.rows: int = 0
self.cols: int = 0
self.hpix: int = 0
self.vpix: int = 0
self.stdout_buf = bytearray()
self.stdout_eof_sent = False
self.stderr_buf = bytearray()
self.stderr_eof_sent = False
self.return_code: int | None = None
self.return_code_sent = False
self.process: process.CallbackSubprocess | None = None
if self.allow_all: self._set_state(LSState.LSSTATE_WAIT_VERS)
else: self._set_state(LSState.LSSTATE_WAIT_IDENT)
self.sessions.append(self)
protocol.register_message_types(self.channel)
self.channel.add_message_handler(self._handle_message)
def _terminated(self, return_code: int):
self.return_code = return_code
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
RNS.log(f"Set state: {state.name}, timeout {timeout}", RNS.LOG_DEBUG)
orig_state = self.state
self.state = state
if timeout_factor is not None:
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
def _call(self, func: callable, delay: float = 0):
def call_inner():
if delay == 0: func()
else: self.loop.call_later(delay, func)
self.loop.call_soon_threadsafe(call_inner)
def send(self, message: RNS.MessageBase):
self.channel.send(message)
def _protocol_error(self, name: str):
self.terminate(f"Protocol error ({name})")
def _protocol_timeout_error(self, name: str):
self.terminate(f"Protocol timeout error: {name}")
def terminate(self, error: str = None):
with contextlib.suppress(Exception):
RNS.log("Terminating session" + (f": {error}" if error else ""), RNS.LOG_DEBUG)
if error and self.state != LSState.LSSTATE_TEARDOWN:
with contextlib.suppress(Exception):
self.send(protocol.ErrorMessage(error, True))
self.state = LSState.LSSTATE_ERROR
self._terminate_process()
self._call(self._prune, max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5))
def _prune(self):
self.state = LSState.LSSTATE_TEARDOWN
RNS.log("Pruning session", RNS.LOG_DEBUG)
with contextlib.suppress(ValueError):
self.sessions.remove(self)
with contextlib.suppress(Exception):
self.outlet.teardown()
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
timeout = True
try: timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
except Exception as e: RNS.log(f"Error in protocol timeout: {e}", RNS.LOG_ERROR)
if timeout: self._protocol_timeout_error(name)
def _link_closed(self, outlet: LSOutletBase):
outlet.unset_link_closed_callback()
if outlet != self.outlet:
RNS.log("Link closed received from incorrect outlet", RNS.LOG_DEBUG)
return
RNS.log(f"link_closed {outlet}", RNS.LOG_DEBUG)
self.terminate()
def _initiator_identified(self, outlet, identity):
if outlet != self.outlet:
RNS.log("Identity received from incorrect outlet", RNS.LOG_DEBUG)
return
RNS.log(f"initiator_identified {identity} on link {outlet}", RNS.LOG_INFO)
if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]:
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
if not self.allow_all and identity.hash not in self.allowed_identity_hashes and identity.hash not in self.allowed_file_identity_hashes:
self.terminate("Identity is not allowed.")
self.remote_identity = identity
self._set_state(LSState.LSSTATE_WAIT_VERS)
@classmethod
async def pump_all(cls) -> True:
processed_any = False
for session in cls.sessions:
processed = session.pump()
processed_any = processed_any or processed
await asyncio.sleep(0)
@classmethod
async def terminate_all(cls, reason: str):
for session in cls.sessions:
session.terminate(reason)
await asyncio.sleep(0)
def pump(self) -> bool:
def compress_adaptive(buf: bytes):
comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
comp_try = 1
comp_success = False
chunk_len = len(buf)
if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN:
chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN
chunk_segment = None
chunk_segment = None
max_data_len = self.channel.mdu - protocol.StreamDataMessage.OVERHEAD
while chunk_len > 32 and comp_try < comp_tries:
chunk_segment_length = int(chunk_len/comp_try)
compressed_chunk = bz2.compress(buf[:chunk_segment_length])
compressed_length = len(compressed_chunk)
if compressed_length < max_data_len and compressed_length < chunk_segment_length:
comp_success = True
break
else:
comp_try += 1
if comp_success:
diff = max_data_len - len(compressed_chunk)
chunk = compressed_chunk
processed_length = chunk_segment_length
else:
chunk = bytes(buf[:max_data_len])
processed_length = len(chunk)
return comp_success, processed_length, chunk
try:
if self.state != LSState.LSSTATE_RUNNING:
return False
elif not self.channel.is_ready_to_send():
return False
elif len(self.stderr_buf) > 0:
comp_success, processed_length, data = compress_adaptive(self.stderr_buf)
self.stderr_buf = self.stderr_buf[processed_length:]
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
data, send_eof, comp_success)
self.send(msg)
if send_eof:
self.stderr_eof_sent = True
return True
elif len(self.stdout_buf) > 0:
comp_success, processed_length, data = compress_adaptive(self.stdout_buf)
self.stdout_buf = self.stdout_buf[processed_length:]
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
data, send_eof, comp_success)
self.send(msg)
if send_eof:
self.stdout_eof_sent = True
return True
elif self.return_code is not None and not self.return_code_sent:
msg = protocol.CommandExitedMessage(self.return_code)
self.send(msg)
self.return_code_sent = True
self._call(functools.partial(self._check_protocol_timeout,
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
max(self.outlet.rtt * 5, 10))
return False
except Exception as e: RNS.log(f"Error during pump: {e}", RNS.LOG_ERROR)
return False
def _terminate_process(self):
with contextlib.suppress(Exception):
if self.process and self.process.running:
self.process.terminate()
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
self.cmdline = self.default_command
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
self.terminate("Remote command line not allowed by listener")
return
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
self.cmdline.extend(cmdline)
elif cmdline and len(cmdline) > 0:
self.cmdline = cmdline
self.stdin_is_pipe = pipe_stdin
self.stdout_is_pipe = pipe_stdout
self.stderr_is_pipe = pipe_stderr
self.tcflags = tcflags
self.term = term
def stdout(data: bytes):
self.stdout_buf.extend(data)
def stderr(data: bytes):
self.stderr_buf.extend(data)
try:
self.process = process.CallbackSubprocess(argv=self.cmdline,
env={"TERM": self.term or os.environ.get("TERM") or "xterm",
"RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash)
if self.remote_identity and self.remote_identity.hash else "")},
loop=self.loop,
stdout_callback=stdout,
stderr_callback=stderr,
terminated_callback=self._terminated,
stdin_is_pipe=self.stdin_is_pipe,
stdout_is_pipe=self.stdout_is_pipe,
stderr_is_pipe=self.stderr_is_pipe)
self.process.start()
self._set_window_size(rows, cols, hpix, vpix)
except Exception as e:
RNS.log(f"Unable to start process for link {self.outlet}: {e}", RNS.LOG_ERROR)
self.terminate("Unable to start process")
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
self.rows = rows
self.cols = cols
self.hpix = hpix
self.vpix = vpix
with contextlib.suppress(Exception):
self.process.set_winsize(rows, cols, hpix, vpix)
def _received_stdin(self, data: bytes, eof: bool):
if data and len(data) > 0:
self.process.write(data)
if eof:
self.process.close_stdin()
def _handle_message(self, message: RNS.MessageBase):
if self.state == LSState.LSSTATE_WAIT_IDENT:
# Ignore any messages until the initiator has identified to avoid race conditions
# between identity announcement and early protocol messages.
RNS.log("Ignoring message while waiting for identification", RNS.LOG_DEBUG)
return
if self.state == LSState.LSSTATE_WAIT_VERS:
if not isinstance(message, protocol.VersionInfoMessage):
self._protocol_error(self.state.name)
return
RNS.log(f"Version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}", RNS.LOG_VERBOSE)
if message.protocol_version != protocol.PROTOCOL_VERSION:
self.terminate("Incompatible protocol")
return
self.send(protocol.VersionInfoMessage())
self._set_state(LSState.LSSTATE_WAIT_CMD)
return
elif self.state == LSState.LSSTATE_WAIT_CMD:
if not isinstance(message, protocol.ExecuteCommandMesssage):
return self._protocol_error(self.state.name)
RNS.log(f"Execute command message on link {self.outlet}: {message.cmdline}", RNS.LOG_VERBOSE)
self._set_state(LSState.LSSTATE_RUNNING)
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
return
elif self.state == LSState.LSSTATE_RUNNING:
if isinstance(message, protocol.WindowSizeMessage):
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
elif isinstance(message, protocol.StreamDataMessage):
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
RNS.log(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}", RNS.LOG_ERROR)
return self._protocol_error(self.state.name)
self._received_stdin(message.data, message.eof)
return
elif isinstance(message, protocol.NoopMessage):
# echo noop only on listener--used for keepalive/connectivity check
self.send(message)
return
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
RNS.log(f"Received packet, but in state {self.state.name}", RNS.LOG_ERROR)
return
else:
self._protocol_error("unexpected message")
return
class RNSOutlet(LSOutletBase):
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
def inner_cb(link, identity: _TIdentity):
cb(self, identity)
self.link.set_remote_identified_callback(inner_cb)
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
def inner_cb(link):
cb(self)
self.link.set_link_closed_callback(inner_cb)
def unset_link_closed_callback(self):
self.link.set_link_closed_callback(None)
def teardown(self):
self.link.teardown()
@property
def rtt(self) -> float:
return self.link.rtt
def __str__(self):
return f"Outlet RNS Link {self.link}"
def __init__(self, link: RNS.Link):
self.link = link
link.lsoutlet = self
@staticmethod
def get_outlet(link: RNS.Link):
if hasattr(link, "lsoutlet"):
return link.lsoutlet
return RNSOutlet(link)