Reticulum/RNS/Utilities/rnsh/initiator.py
2026-04-27 00:06:33 +02:00

474 lines
No EOL
20 KiB
Python

#!/usr/bin/env python3
# MIT License
#
# Copyright (c) 2016-2022 Mark Qvist / unsigned.io
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in all
# copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
from __future__ import annotations
import asyncio
import base64
import enum
import functools
import os
import queue
import shlex
import signal
import sys
import termios
import threading
import time
import tty
from typing import Callable, TypeVar
import RNS
import RNS.Utilities.rnsh.exception as exception
import RNS.Utilities.rnsh.process as process
import RNS.Utilities.rnsh.retry as retry
import RNS.Utilities.rnsh.session as session
import re
import contextlib
import pwd
import bz2
import RNS.Utilities.rnsh.protocol as protocol
import RNS.Utilities.rnsh.helpers as helpers
import RNS.Utilities.rnsh.rnsh as rnsh
# Module-level session state. These names start out empty and are filled in
# lazily by _initiate_link() and initiate() below.
_identity = None  # assigned from rnsh.prepare_identity() in _initiate_link()
_reticulum = None  # RNS.Reticulum instance, created once in _initiate_link()
_cmd: list[str] | None = None  # not referenced elsewhere in the visible code
DATA_AVAIL_MSG = "data available"  # not referenced elsewhere in the visible code
# Created in initiate(); set by the signal handlers and the link-closed
# callback to ask the main loop to stop.
_finished: asyncio.Event | None = None
_retry_timer: retry.RetryThread | None = None  # not referenced elsewhere in the visible code
_destination: RNS.Destination | None = None  # outbound destination, created once in _initiate_link()
_loop: asyncio.AbstractEventLoop | None = None  # not referenced elsewhere in the visible code
async def _check_finished(timeout: float = 0):
    """Report whether the session has been asked to finish.

    Returns False straight away while the module-level ``_finished`` event has
    not been created yet. Otherwise the result of
    ``process.event_wait(_finished, timeout=timeout)`` is returned
    (presumably True once the event is set within ``timeout`` seconds).
    """
    if _finished is None:
        return False
    return await process.event_wait(_finished, timeout=timeout)
def _sigint_handler(sig, loop):
    """Handle SIGINT/SIGTERM by requesting an orderly shutdown.

    The signal name is logged at debug level. If the ``_finished`` event
    exists it is set so the main loop can wind down; before that point a
    ``KeyboardInterrupt`` is raised instead. ``loop`` is accepted but unused.
    """
    RNS.log(f"{signal.Signals(sig).name}", RNS.LOG_DEBUG)
    if _finished is None:
        raise KeyboardInterrupt()
    _finished.set()
async def _spin_tty(until=None, msg=None, timeout=None):
    """Show an animated spinner on stdout until a condition holds or time runs out.

    Args:
        until: Zero-argument callable, polled about every 0.1 s. The wait ends
            as soon as it returns a truthy value. It must be supplied; the
            ``None`` default only mirrors the sibling helpers' signatures.
        msg: Text printed in front of the spinner. ``None`` is treated as an
            empty string.
        timeout: Maximum number of seconds to wait, or ``None`` to wait
            without a deadline.

    Returns:
        True if ``until()`` was observed truthy before the deadline, False if
        the deadline passed first.
    """
    text = "" if msg is None else msg
    deadline = None if timeout is None else time.time() + timeout
    syms = "⢄⢂⢁⡁⡈⡐⡠"
    i = 0
    # The outcome is decided by what was actually observed, not by re-reading
    # the clock after the loop, so a condition met right at the deadline is
    # never misreported as a timeout (or a timeout as success).
    satisfied = False

    print(text + " ", end=" ")
    while deadline is None or time.time() < deadline:
        if until():
            satisfied = True
            break
        await asyncio.sleep(0.1)
        print(("\b\b" + syms[i] + " "), end="")
        sys.stdout.flush()
        i = (i + 1) % len(syms)

    # Overwrite the message and spinner with blanks and return to column 0.
    print("\r" + " " * len(text) + " \r", end="")
    return satisfied
async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = None) -> bool:
if timeout is not None: timeout += time.time()
while (timeout is None or time.time() < timeout) and not until():
if await _check_finished(0.1): raise asyncio.CancelledError()
if timeout is not None and time.time() > timeout: return False
else: return True
async def _spin(until: callable = None, msg=None, timeout: float | None = None, quiet: bool = False) -> bool:
    """Wait for ``until()`` using the spinner suited to the current stdout.

    The animated tty spinner is used when ``quiet`` is false and file
    descriptor 1 is a terminal; otherwise the silent pipe variant is used.
    The chosen helper's result is returned unchanged.
    """
    interactive = not quiet and os.isatty(1)
    waiter = _spin_tty if interactive else _spin_pipe
    return await waiter(until, msg, timeout)
_link: RNS.Link | None = None  # link to the listener, created in _initiate_link()
_remote_exec_grace = 2.0  # not read anywhere in the visible code
# Inbound channel messages: filled by _client_message_handler() and drained
# by initiate().
_pq = queue.Queue()
@enum.unique
class InitiatorState(enum.IntEnum):
    """States of the initiator (client) side of a session.

    In the visible code only IS_INITIAL, IS_LINKED and IS_RUNNING are
    assigned; the remaining members are declared but not used there.
    """

    IS_INITIAL = 0    # starting state in initiate()
    IS_LINKED = 1     # link obtained from _initiate_link()
    IS_WAIT_VERS = 2
    IS_RUNNING = 3    # version info received; main loop may run
    IS_TERMINATE = 4
    IS_TEARDOWN = 5
def _client_link_closed(link):
    """Link-closed callback: signal the main loop to stop.

    Sets the ``_finished`` event when one exists; ``link`` is unused.
    """
    if not _finished:
        return
    _finished.set()
def _client_message_handler(message: RNS.MessageBase):
    """Channel message handler: queue ``message`` on ``_pq`` for initiate()."""
    _pq.put(message)
def compute_target_rns_loglevel(verbosity: int, quietness: int, base_level: int = RNS.LOG_INFO) -> int:
    """Derive an RNS log level from verbosity and quietness counts.

    The level is ``base_level + verbosity - quietness``, clamped to the range
    ``RNS.LOG_CRITICAL`` .. ``RNS.LOG_DEBUG``. If anything goes wrong (for
    example an argument that cannot be converted to ``int``), ``base_level``
    is returned unchanged.
    """
    try:
        requested = int(base_level) + int(verbosity) - int(quietness)
        # Raise to the floor first, then cap at the ceiling.
        return min(max(requested, RNS.LOG_CRITICAL), RNS.LOG_DEBUG)
    except Exception:
        return base_level
class RemoteExecutionError(Exception):
    """Raised when a remote session cannot be set up or the remote reports an error.

    Attributes are documented inline: ``msg`` holds the human-readable
    description that callers print.
    """

    def __init__(self, msg):
        # Forward the message to Exception so str(err), err.args and
        # tracebacks carry it instead of being empty.
        super().__init__(msg)
        # Human-readable description; read by callers as ``err.msg``.
        self.msg = msg
async def _initiate_link(configdir, rnsconfigdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
                         timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
    """Bring up (or reuse) the link to the listener and optionally identify on it.

    Module-level state (``_reticulum``, ``_identity``, ``_destination``,
    ``_link``) is created on first use and reused on later calls.

    Args:
        configdir: Directory in which the RNS log file ("logfile") is placed.
        rnsconfigdir: Config directory passed to ``RNS.Reticulum``.
        identitypath: Passed to ``rnsh.prepare_identity``.
        verbosity: Count that raises the log level.
        quietness: Count that lowers the log level; any value above zero also
            silences the progress spinner.
        noid: When true, do not identify on the link.
        destination: Destination hash as a hexadecimal string.
        timeout: Seconds to wait for the path request and, separately, for
            the link to become active.

    Raises:
        RemoteExecutionError: If the destination is missing, has the wrong
            length or is not valid hex; if no path is found in time; or if
            the link does not become active in time.
    """
    global _identity, _reticulum, _link, _destination

    dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
    # A missing destination is reported like any other malformed one instead
    # of surfacing as a TypeError from len(None).
    if destination is None or len(destination) != dest_len:
        raise RemoteExecutionError(
            "Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(
                hex=dest_len, byte=dest_len // 2))
    try:
        destination_hash = bytes.fromhex(destination)
    except Exception as e:
        # Keep the parsing error attached as the cause for debugging.
        raise RemoteExecutionError("Invalid destination entered. Check your input.") from e

    if _reticulum is None:
        targetloglevel = compute_target_rns_loglevel(verbosity, quietness, RNS.LOG_ERROR)
        RNS.logfile = os.path.join(configdir, "logfile")
        _reticulum = RNS.Reticulum(configdir=rnsconfigdir, loglevel=targetloglevel, logdest=RNS.LOG_FILE)

    if _identity is None:
        _identity = rnsh.prepare_identity(identitypath)

    if not RNS.Transport.has_path(destination_hash):
        RNS.Transport.request_path(destination_hash)
        RNS.log("Requesting path...", RNS.LOG_INFO)
        if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...",
                           timeout=timeout, quiet=quietness > 0):
            raise RemoteExecutionError("Path not found")

    if _destination is None:
        listener_identity = RNS.Identity.recall(destination_hash)
        _destination = RNS.Destination(
            listener_identity,
            RNS.Destination.OUT,
            RNS.Destination.SINGLE,
            rnsh.APP_NAME
        )

    if _link is None or _link.status == RNS.Link.PENDING:
        RNS.log("No link", RNS.LOG_DEBUG)
        _link = RNS.Link(_destination)
        # Ad-hoc attribute tracking whether identify() has been sent on this link.
        _link.did_identify = False

    _link.set_link_closed_callback(_client_link_closed)

    RNS.log("Establishing link...", RNS.LOG_VERBOSE)
    if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...",
                       timeout=timeout, quiet=quietness > 0):
        raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))

    RNS.log("Have link", RNS.LOG_DEBUG)
    if not noid and not _link.did_identify:
        # Delay a tiny bit to allow listener to fully enter WAIT_IDENT state
        await asyncio.sleep(min(1, _link.rtt * 1.1 + 0.05))
        _link.identify(_identity)
        _link.did_identify = True
async def _handle_error(errmsg: RNS.MessageBase):
    """Turn a remote ``protocol.ErrorMessage`` into a ``RemoteExecutionError``.

    Messages of any other type are ignored. For an error message the link is
    torn down on a best-effort basis (only while it is active, followed by a
    0.1 s pause), and then the error is raised with the remote's message.
    """
    if not isinstance(errmsg, protocol.ErrorMessage):
        return

    # Best effort: failures during teardown must not mask the remote error.
    with contextlib.suppress(Exception):
        if _link and _link.status == RNS.Link.ACTIVE:
            _link.teardown()
            await asyncio.sleep(0.1)

    # Raised outside the suppress block so that it reaches the caller.
    raise RemoteExecutionError(f"Remote error: {errmsg.msg}")
async def initiate(configdir: str, rnsconfigdir:str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
                   timeout: float, command: list[str] | None = None):
    """Run one client session against a remote listener and return its exit code.

    Establishes the link, exchanges version information, sends the execute
    request, and then relays data until the session ends: local stdin goes to
    the remote, and remote stdout/stderr are written to file descriptors 1
    and 2.

    Args:
        configdir: Passed to ``_initiate_link``.
        rnsconfigdir: Passed to ``_initiate_link``.
        identitypath: Passed to ``_initiate_link``.
        verbosity: Passed to ``_initiate_link``.
        quietness: Passed to ``_initiate_link``; values above zero also
            silence the progress spinner.
        noid: Passed to ``_initiate_link``.
        destination: Destination hash as a hexadecimal string.
        timeout: Passed to ``_initiate_link``.
        command: Command line sent to the remote as ``cmdline``.

    Returns:
        - the remote's return code when a ``CommandExitedMessage`` arrives;
        - 255 if no usable link exists, or a ``RemoteExecutionError`` is
          caught in the main loop;
        - 254 if no version reply arrives in time;
        - 200 after a fatal remote ``ErrorMessage``;
        - 127 after an unexpected client exception while the link is not yet
          closed;
        - 0 when the main loop ends normally.

    Raises:
        RemoteExecutionError: From link setup or the version exchange, which
            happen before the main loop's handlers are in place.
        Exception: If the version reply is not a ``VersionInfoMessage``.
    """
    global _finished, _link
    with process.TTYRestorer(sys.stdin.fileno()) as ttyRestorer:
        loop = asyncio.get_running_loop()
        state = InitiatorState.IS_INITIAL
        # Piped stdin is read in full up front; interactive stdin is fed in
        # later by the reader callback below.
        data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
        line_buffer = bytearray()

        await _initiate_link(configdir=configdir,
                             rnsconfigdir=rnsconfigdir,
                             identitypath=identitypath,
                             verbosity=verbosity,
                             quietness=quietness,
                             noid=noid,
                             destination=destination,
                             timeout=timeout)

        if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]:
            return 255

        state = InitiatorState.IS_LINKED
        outlet = session.RNSOutlet(_link)
        channel = _link.get_channel()
        protocol.register_message_types(channel)
        channel.add_message_handler(_client_message_handler)

        # Next step after linking and identifying: send version
        channel.send(protocol.VersionInfoMessage())
        try:
            vm = _pq.get(timeout=max(outlet.rtt * 20, 5))
            await _handle_error(vm)
            if not isinstance(vm, protocol.VersionInfoMessage):
                raise Exception("Invalid message received")
            RNS.log(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}", RNS.LOG_DEBUG)
            state = InitiatorState.IS_RUNNING
        except queue.Empty:
            print("Protocol error")
            return 254

        winch = False

        def sigwinch_handler():
            """Note that the terminal was resized; the main loop sends the update."""
            nonlocal winch
            winch = True

        esc = False          # previous character was the escape character "~"
        pre_esc = True       # an escape sequence may start here (just after "\r")
        line_mode = False    # line-interactive mode: buffer and echo locally
        line_flush = False   # pending request to flush the buffered line
        blind_write_count = 0  # number of locally echoed bytes not yet erased
        flush_chars = ["\x01", "\x03", "\x04", "\x05", "\x0c", "\x11", "\x13", "\x15", "\x19", "\t", "\x1A", "\x1B"]

        def handle_escape(b):
            """Act on the character following "~".

            Returns None when the sequence was consumed, otherwise ``b`` so
            the caller can forward it.
            """
            nonlocal line_mode
            if b == "?":
                os.write(1, "\n\r\n\rSupported rnsh escape sequences:".encode("utf-8"))
                os.write(1, "\n\r ~~ Send the escape character by typing it twice".encode("utf-8"))
                os.write(1, "\n\r ~. Terminate session and exit immediately".encode("utf-8"))
                os.write(1, "\n\r ~L Toggle line-interactive mode".encode("utf-8"))
                os.write(1, "\n\r ~? Display this quick reference\n\r".encode("utf-8"))
                os.write(1, "\n\r(Escape sequences are only recognized immediately after newline)\n\r".encode("utf-8"))
                return None
            elif b == ".":
                _link.teardown()
                return None
            elif b == "L":
                line_mode = not line_mode
                if line_mode:
                    os.write(1, "\n\rLine-interactive mode enabled\n\r".encode("utf-8"))
                else:
                    os.write(1, "\n\rLine-interactive mode disabled\n\r".encode("utf-8"))
                return None
            return b

        stdin_eof = False

        def stdin():
            """Reader callback: move available stdin bytes into the send buffers."""
            nonlocal stdin_eof, pre_esc, esc, line_mode
            nonlocal line_flush, blind_write_count
            try:
                in_data = process.tty_read(sys.stdin.fileno())
                if in_data is not None:
                    data = bytearray()
                    for b in bytes(in_data):
                        c = chr(b)
                        if c == "\r":
                            pre_esc = True
                            line_flush = True
                            data.append(b)
                        elif line_mode and c in flush_chars:
                            pre_esc = False
                            line_flush = True
                            data.append(b)
                        elif line_mode and (c == "\b" or c == "\x7f"):
                            # Local line editing: drop the last buffered byte
                            # and erase its echo.
                            pre_esc = False
                            if len(line_buffer) > 0:
                                line_buffer.pop(-1)
                                blind_write_count -= 1
                                os.write(1, "\b \b".encode("utf-8"))
                        elif pre_esc and c == "~":
                            pre_esc = False
                            esc = True
                        elif esc:
                            ret = handle_escape(c)
                            if ret is not None:
                                # Not a recognised sequence: forward the "~"
                                # too, except for "~~" which sends one "~".
                                if ret != "~":
                                    data.append(ord("~"))
                                data.append(ord(ret))
                            esc = False
                        else:
                            pre_esc = False
                            data.append(b)

                    if not line_mode:
                        data_buffer.extend(data)
                    else:
                        line_buffer.extend(data)
                        if line_flush:
                            data_buffer.extend(line_buffer)
                            line_buffer.clear()
                            # Erase the local echo before the line is sent.
                            os.write(1, ("\b \b"*blind_write_count).encode("utf-8"))
                            line_flush = False
                            blind_write_count = 0
                        else:
                            os.write(1, data)
                            blind_write_count += len(data)
            except EOFError:
                if os.isatty(0):
                    data_buffer.extend(process.CTRL_D)
                stdin_eof = True
                process.tty_unset_reader_callbacks(sys.stdin.fileno())

        process.tty_add_reader_callback(sys.stdin.fileno(), stdin)

        # Probe stdin, stdout and stderr in turn for terminal attributes and
        # window size, stopping at the first descriptor where both succeed.
        tcattr = None
        rows, cols, hpix, vpix = (None, None, None, None)
        for fd in (0, 1, 2):
            try:
                tcattr = termios.tcgetattr(fd)
                rows, cols, hpix, vpix = process.tty_get_winsize(fd)
                break
            except Exception:
                continue

        await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0)
        channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
                                                     pipe_stdin=not os.isatty(0),
                                                     pipe_stdout=not os.isatty(1),
                                                     pipe_stderr=not os.isatty(2),
                                                     tcflags=tcattr,
                                                     term=os.environ.get("TERM", None),
                                                     rows=rows,
                                                     cols=cols,
                                                     hpix=hpix,
                                                     vpix=vpix))

        loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
        _finished = asyncio.Event()
        loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop))
        loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop))

        def compress_adaptive(buf: bytes):
            """Pick the next stdin chunk, bz2-compressed when that pays off.

            Tries progressively shorter prefixes of ``buf`` until the
            compressed form both fits in one message and is smaller than the
            prefix itself. Returns ``(compressed, consumed_length, chunk)``.
            """
            comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
            comp_try = 1
            comp_success = False
            chunk_len = min(len(buf), RNS.RawChannelWriter.MAX_CHUNK_LEN)
            max_data_len = channel.mdu - protocol.StreamDataMessage.OVERHEAD
            while chunk_len > 32 and comp_try < comp_tries:
                chunk_segment_length = chunk_len // comp_try
                compressed_chunk = bz2.compress(buf[:chunk_segment_length])
                compressed_length = len(compressed_chunk)
                if compressed_length < max_data_len and compressed_length < chunk_segment_length:
                    comp_success = True
                    break
                comp_try += 1

            if comp_success:
                chunk = compressed_chunk
                processed_length = chunk_segment_length
            else:
                chunk = bytes(buf[:max_data_len])
                processed_length = len(chunk)
            return comp_success, processed_length, chunk

        sent_eof = False
        last_winch = time.time()
        sleeper = helpers.SleepRate(0.01)
        processed = False
        while not await _check_finished() and state == InitiatorState.IS_RUNNING:
            try:
                try:
                    # Poll quickly while there is traffic, otherwise back off.
                    message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
                    await _handle_error(message)
                    processed = True
                    if isinstance(message, protocol.StreamDataMessage):
                        if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
                            if message.data:
                                ttyRestorer.raw()
                                RNS.log(f"stdout: {message.data}", RNS.LOG_DEBUG)
                                os.write(1, message.data)
                                sys.stdout.flush()
                            if message.eof:
                                os.close(1)
                        elif message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
                            if message.data:
                                ttyRestorer.raw()
                                RNS.log(f"stderr: {message.data}", RNS.LOG_DEBUG)
                                os.write(2, message.data)
                                sys.stderr.flush()
                            if message.eof:
                                os.close(2)
                    elif isinstance(message, protocol.CommandExitedMessage):
                        RNS.log(f"received return code {message.return_code}, exiting", RNS.LOG_DEBUG)
                        return message.return_code
                    elif isinstance(message, protocol.ErrorMessage):
                        RNS.log(f"Remote error: {message.data}", RNS.LOG_ERROR)
                        if message.fatal:
                            _link.teardown()
                            return 200
                except queue.Empty:
                    processed = False

                if channel.is_ready_to_send():
                    comp_success, processed_length, stdin_chunk = compress_adaptive(data_buffer)
                    data_buffer = data_buffer[processed_length:]
                    # EOF is sent once, and only after all buffered input has
                    # been flushed.
                    eof = not sent_eof and stdin_eof and len(stdin_chunk) == 0
                    if len(stdin_chunk) > 0 or eof:
                        channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin_chunk, eof, comp_success))
                        sent_eof = eof
                        processed = True

                # send window change, but rate limited
                if winch and time.time() - last_winch > _link.rtt * 25:
                    last_winch = time.time()
                    winch = False
                    with contextlib.suppress(Exception):
                        r, c, h, v = process.tty_get_winsize(0)
                        channel.send(protocol.WindowSizeMessage(r, c, h, v))
                        processed = True
            except RemoteExecutionError as e:
                print(e.msg)
                return 255
            except Exception as ex:
                print(f"Client exception: {ex}")
                if _link and _link.status != RNS.Link.CLOSED:
                    _link.teardown()
                    return 127

        RNS.log("Main loop done", RNS.LOG_DEBUG)
        return 0