Added rnsh to included utilities
This commit is contained in:
parent
3eee369704
commit
6abb31e469
12 changed files with 2431 additions and 0 deletions
29
RNS/Utilities/rnsh/__init__.py
Normal file
29
RNS/Utilities/rnsh/__init__.py
Normal file
|
|
@ -0,0 +1,29 @@
|
|||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Aaron Heise
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from ._version import __version__
|
||||
|
||||
import os
|
||||
module_abs_filename = os.path.abspath(__file__)
|
||||
module_dir = os.path.dirname(module_abs_filename)
|
||||
|
||||
def _get_version(): return __version__
|
||||
1
RNS/Utilities/rnsh/_version.py
Normal file
1
RNS/Utilities/rnsh/_version.py
Normal file
|
|
@ -0,0 +1 @@
|
|||
__version__ = "0.2.0"
|
||||
26
RNS/Utilities/rnsh/exception.py
Normal file
26
RNS/Utilities/rnsh/exception.py
Normal file
|
|
@ -0,0 +1,26 @@
|
|||
import contextlib
|
||||
from contextlib import AbstractContextManager
|
||||
import logging
|
||||
import sys
|
||||
|
||||
|
||||
class permit(AbstractContextManager):
|
||||
"""Context manager to allow specified exceptions
|
||||
|
||||
The specified exceptions will be allowed to bubble up. Other
|
||||
exceptions are suppressed.
|
||||
|
||||
After a non-matching exception is suppressed, execution proceeds
|
||||
with the next statement following the with statement.
|
||||
|
||||
with allow(KeyboardInterrupt):
|
||||
time.sleep(300)
|
||||
# Execution still resumes here if no KeyboardInterrupt
|
||||
"""
|
||||
|
||||
def __init__(self, *exceptions): self._exceptions = exceptions
|
||||
|
||||
def __enter__(self): pass
|
||||
|
||||
def __exit__(self, exctype, excinst, exctb):
|
||||
return exctype is not None and not issubclass(exctype, self._exceptions)
|
||||
25
RNS/Utilities/rnsh/helpers.py
Normal file
25
RNS/Utilities/rnsh/helpers.py
Normal file
|
|
@ -0,0 +1,25 @@
|
|||
import asyncio
|
||||
import time
|
||||
|
||||
def bitwise_or_if(value: int, condition: bool, orval: int):
|
||||
if not condition: return value
|
||||
return value | orval
|
||||
|
||||
def check_and(value: int, andval: int) -> bool:
|
||||
return (value & andval) > 0
|
||||
|
||||
class SleepRate:
|
||||
def __init__(self, target_period: float):
|
||||
self.target_period = target_period
|
||||
self.last_wake = time.time()
|
||||
|
||||
def next_sleep_time(self) -> float:
|
||||
old_last_wake = self.last_wake
|
||||
self.last_wake = time.time()
|
||||
next_wake = max(old_last_wake + 0.01, self.last_wake)
|
||||
sleep_for = next_wake - self.last_wake
|
||||
return sleep_for if sleep_for > 0 else 0
|
||||
|
||||
async def sleep_async(self): await asyncio.sleep(self.next_sleep_time())
|
||||
|
||||
def sleep_block(self): time.sleep(self.next_sleep_time())
|
||||
474
RNS/Utilities/rnsh/initiator.py
Normal file
474
RNS/Utilities/rnsh/initiator.py
Normal file
|
|
@ -0,0 +1,474 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2016-2022 Mark Qvist / unsigned.io
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
import enum
|
||||
import functools
|
||||
import os
|
||||
import queue
|
||||
import shlex
|
||||
import signal
|
||||
import sys
|
||||
import termios
|
||||
import threading
|
||||
import time
|
||||
import tty
|
||||
from typing import Callable, TypeVar
|
||||
import RNS
|
||||
import rnsh.exception as exception
|
||||
import rnsh.process as process
|
||||
import rnsh.retry as retry
|
||||
import rnsh.session as session
|
||||
import re
|
||||
import contextlib
|
||||
import rnsh.args
|
||||
import pwd
|
||||
import bz2
|
||||
import rnsh.protocol as protocol
|
||||
import rnsh.helpers as helpers
|
||||
import rnsh.rnsh
|
||||
|
||||
_identity = None
|
||||
_reticulum = None
|
||||
_cmd: [str] | None = None
|
||||
DATA_AVAIL_MSG = "data available"
|
||||
_finished: asyncio.Event = None
|
||||
_retry_timer: retry.RetryThread | None = None
|
||||
_destination: RNS.Destination | None = None
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
|
||||
async def _check_finished(timeout: float = 0):
|
||||
return _finished is not None and await process.event_wait(_finished, timeout=timeout)
|
||||
|
||||
def _sigint_handler(sig, loop):
|
||||
global _finished
|
||||
RNS.log(f"{signal.Signals(sig).name}", RNS.LOG_DEBUG)
|
||||
if _finished is not None: _finished.set()
|
||||
else: raise KeyboardInterrupt()
|
||||
|
||||
async def _spin_tty(until=None, msg=None, timeout=None):
|
||||
i = 0
|
||||
syms = "⢄⢂⢁⡁⡈⡐⡠"
|
||||
if timeout != None: timeout = time.time()+timeout
|
||||
|
||||
print(msg+" ", end=" ")
|
||||
while (timeout == None or time.time()<timeout) and not until():
|
||||
await asyncio.sleep(0.1)
|
||||
print(("\b\b"+syms[i]+" "), end="")
|
||||
sys.stdout.flush()
|
||||
i = (i+1)%len(syms)
|
||||
|
||||
print("\r"+" "*len(msg)+" \r", end="")
|
||||
|
||||
if timeout != None and time.time() > timeout: return False
|
||||
else: return True
|
||||
|
||||
|
||||
async def _spin_pipe(until: callable = None, msg=None, timeout: float | None = None) -> bool:
|
||||
if timeout is not None: timeout += time.time()
|
||||
|
||||
while (timeout is None or time.time() < timeout) and not until():
|
||||
if await _check_finished(0.1): raise asyncio.CancelledError()
|
||||
|
||||
if timeout is not None and time.time() > timeout: return False
|
||||
else: return True
|
||||
|
||||
async def _spin(until: callable = None, msg=None, timeout: float | None = None, quiet: bool = False) -> bool:
|
||||
if not quiet and os.isatty(1): return await _spin_tty(until, msg, timeout)
|
||||
else: return await _spin_pipe(until, msg, timeout)
|
||||
|
||||
_link: RNS.Link | None = None
|
||||
_remote_exec_grace = 2.0
|
||||
_pq = queue.Queue()
|
||||
|
||||
|
||||
class InitiatorState(enum.IntEnum):
|
||||
IS_INITIAL = 0
|
||||
IS_LINKED = 1
|
||||
IS_WAIT_VERS = 2
|
||||
IS_RUNNING = 3
|
||||
IS_TERMINATE = 4
|
||||
IS_TEARDOWN = 5
|
||||
|
||||
def _client_link_closed(link):
|
||||
if _finished: _finished.set()
|
||||
|
||||
def _client_message_handler(message: RNS.MessageBase): _pq.put(message)
|
||||
|
||||
def compute_target_rns_loglevel(verbosity: int, quietness: int, base_level: int = RNS.LOG_INFO) -> int:
|
||||
try:
|
||||
target = int(base_level) + int(verbosity) - int(quietness)
|
||||
if target < RNS.LOG_CRITICAL: target = RNS.LOG_CRITICAL
|
||||
if target > RNS.LOG_DEBUG: target = RNS.LOG_DEBUG
|
||||
return target
|
||||
|
||||
except Exception: return base_level
|
||||
|
||||
|
||||
class RemoteExecutionError(Exception):
|
||||
def __init__(self, msg): self.msg = msg
|
||||
|
||||
|
||||
async def _initiate_link(configdir, rnsconfigdir, identitypath=None, verbosity=0, quietness=0, noid=False, destination=None,
|
||||
timeout=RNS.Transport.PATH_REQUEST_TIMEOUT):
|
||||
global _identity, _reticulum, _link, _destination, _remote_exec_grace
|
||||
|
||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
|
||||
if len(destination) != dest_len:
|
||||
raise RemoteExecutionError(
|
||||
"Allowed destination length is invalid, must be {hex} hexadecimal characters ({byte} bytes).".format(
|
||||
hex=dest_len, byte=dest_len // 2))
|
||||
try:
|
||||
destination_hash = bytes.fromhex(destination)
|
||||
except Exception as e:
|
||||
raise RemoteExecutionError("Invalid destination entered. Check your input.")
|
||||
|
||||
if _reticulum is None:
|
||||
targetloglevel = compute_target_rns_loglevel(verbosity, quietness, RNS.LOG_ERROR)
|
||||
RNS.logfile = os.path.join(configdir, "logfile")
|
||||
_reticulum = RNS.Reticulum(configdir=rnsconfigdir, loglevel=targetloglevel, logdest=RNS.LOG_FILE)
|
||||
|
||||
if _identity is None:
|
||||
_identity = rnsh.rnsh.prepare_identity(identitypath)
|
||||
|
||||
if not RNS.Transport.has_path(destination_hash):
|
||||
RNS.Transport.request_path(destination_hash)
|
||||
RNS.log(f"Requesting path...", RNS.LOG_INFO)
|
||||
if not await _spin(until=lambda: RNS.Transport.has_path(destination_hash), msg="Requesting path...",
|
||||
timeout=timeout, quiet=quietness > 0):
|
||||
raise RemoteExecutionError("Path not found")
|
||||
|
||||
if _destination is None:
|
||||
listener_identity = RNS.Identity.recall(destination_hash)
|
||||
_destination = RNS.Destination(
|
||||
listener_identity,
|
||||
RNS.Destination.OUT,
|
||||
RNS.Destination.SINGLE,
|
||||
rnsh.rnsh.APP_NAME
|
||||
)
|
||||
|
||||
if _link is None or _link.status == RNS.Link.PENDING:
|
||||
RNS.log("No link", RNS.LOG_DEBUG)
|
||||
_link = RNS.Link(_destination)
|
||||
_link.did_identify = False
|
||||
|
||||
_link.set_link_closed_callback(_client_link_closed)
|
||||
|
||||
RNS.log(f"Establishing link...", RNS.LOG_VERBOSE)
|
||||
if not await _spin(until=lambda: _link.status == RNS.Link.ACTIVE, msg="Establishing link...",
|
||||
timeout=timeout, quiet=quietness > 0):
|
||||
raise RemoteExecutionError("Could not establish link with " + RNS.prettyhexrep(destination_hash))
|
||||
|
||||
RNS.log("Have link", RNS.LOG_DEBUG)
|
||||
if not noid and not _link.did_identify:
|
||||
# Delay a tiny bit to allow listener to fully enter WAIT_IDENT state
|
||||
await asyncio.sleep(min(1, _link.rtt * 1.1 + 0.05))
|
||||
_link.identify(_identity)
|
||||
_link.did_identify = True
|
||||
|
||||
|
||||
async def _handle_error(errmsg: RNS.MessageBase):
|
||||
if isinstance(errmsg, protocol.ErrorMessage):
|
||||
with contextlib.suppress(Exception):
|
||||
if _link and _link.status == RNS.Link.ACTIVE:
|
||||
_link.teardown()
|
||||
await asyncio.sleep(0.1)
|
||||
raise RemoteExecutionError(f"Remote error: {errmsg.msg}")
|
||||
|
||||
|
||||
async def initiate(configdir: str, rnsconfigdir:str, identitypath: str, verbosity: int, quietness: int, noid: bool, destination: str,
|
||||
timeout: float, command: [str] | None = None):
|
||||
global _finished, _link
|
||||
with process.TTYRestorer(sys.stdin.fileno()) as ttyRestorer:
|
||||
loop = asyncio.get_running_loop()
|
||||
state = InitiatorState.IS_INITIAL
|
||||
data_buffer = bytearray(sys.stdin.buffer.read()) if not os.isatty(sys.stdin.fileno()) else bytearray()
|
||||
line_buffer = bytearray()
|
||||
|
||||
await _initiate_link(configdir=configdir,
|
||||
rnsconfigdir=rnsconfigdir,
|
||||
identitypath=identitypath,
|
||||
verbosity=verbosity,
|
||||
quietness=quietness,
|
||||
noid=noid,
|
||||
destination=destination,
|
||||
timeout=timeout)
|
||||
|
||||
if not _link or _link.status not in [RNS.Link.ACTIVE, RNS.Link.PENDING]:
|
||||
return 255
|
||||
|
||||
state = InitiatorState.IS_LINKED
|
||||
outlet = session.RNSOutlet(_link)
|
||||
channel = _link.get_channel()
|
||||
protocol.register_message_types(channel)
|
||||
channel.add_message_handler(_client_message_handler)
|
||||
|
||||
# Next step after linking and identifying: send version
|
||||
# if not await _spin(lambda: messenger.is_outlet_ready(outlet), timeout=5, quiet=quietness > 0):
|
||||
# print("Error bringing up link")
|
||||
# return 253
|
||||
|
||||
channel.send(protocol.VersionInfoMessage())
|
||||
try:
|
||||
vm = _pq.get(timeout=max(outlet.rtt * 20, 5))
|
||||
await _handle_error(vm)
|
||||
if not isinstance(vm, protocol.VersionInfoMessage):
|
||||
raise Exception("Invalid message received")
|
||||
RNS.log(f"Server version info: sw {vm.sw_version} prot {vm.protocol_version}", RNS.LOG_DEBUG)
|
||||
state = InitiatorState.IS_RUNNING
|
||||
except queue.Empty:
|
||||
print("Protocol error")
|
||||
return 254
|
||||
|
||||
winch = False
|
||||
def sigwinch_handler():
|
||||
nonlocal winch
|
||||
winch = True
|
||||
|
||||
esc = False
|
||||
pre_esc = True
|
||||
line_mode = False
|
||||
line_flush = False
|
||||
blind_write_count = 0
|
||||
flush_chars = ["\x01", "\x03", "\x04", "\x05", "\x0c", "\x11", "\x13", "\x15", "\x19", "\t", "\x1A", "\x1B"]
|
||||
def handle_escape(b):
|
||||
nonlocal line_mode
|
||||
if b == "?":
|
||||
os.write(1, "\n\r\n\rSupported rnsh escape sequences:".encode("utf-8"))
|
||||
os.write(1, "\n\r ~~ Send the escape character by typing it twice".encode("utf-8"))
|
||||
os.write(1, "\n\r ~. Terminate session and exit immediately".encode("utf-8"))
|
||||
os.write(1, "\n\r ~L Toggle line-interactive mode".encode("utf-8"))
|
||||
os.write(1, "\n\r ~? Display this quick reference\n\r".encode("utf-8"))
|
||||
os.write(1, "\n\r(Escape sequences are only recognized immediately after newline)\n\r".encode("utf-8"))
|
||||
return None
|
||||
elif b == ".":
|
||||
_link.teardown()
|
||||
return None
|
||||
elif b == "L":
|
||||
line_mode = not line_mode
|
||||
if line_mode:
|
||||
os.write(1, "\n\rLine-interactive mode enabled\n\r".encode("utf-8"))
|
||||
else:
|
||||
os.write(1, "\n\rLine-interactive mode disabled\n\r".encode("utf-8"))
|
||||
return None
|
||||
|
||||
return b
|
||||
|
||||
stdin_eof = False
|
||||
def stdin():
|
||||
nonlocal stdin_eof, pre_esc, esc, line_mode
|
||||
nonlocal line_flush, blind_write_count
|
||||
try:
|
||||
in_data = process.tty_read(sys.stdin.fileno())
|
||||
if in_data is not None:
|
||||
data = bytearray()
|
||||
for b in bytes(in_data):
|
||||
c = chr(b)
|
||||
if c == "\r":
|
||||
pre_esc = True
|
||||
line_flush = True
|
||||
data.append(b)
|
||||
elif line_mode and c in flush_chars:
|
||||
pre_esc = False
|
||||
line_flush = True
|
||||
data.append(b)
|
||||
elif line_mode and (c == "\b" or c == "\x7f"):
|
||||
pre_esc = False
|
||||
if len(line_buffer)>0:
|
||||
line_buffer.pop(-1)
|
||||
blind_write_count -= 1
|
||||
os.write(1, "\b \b".encode("utf-8"))
|
||||
elif pre_esc == True and c == "~":
|
||||
pre_esc = False
|
||||
esc = True
|
||||
elif esc == True:
|
||||
ret = handle_escape(c)
|
||||
if ret != None:
|
||||
if ret != "~":
|
||||
data.append(ord("~"))
|
||||
data.append(ord(ret))
|
||||
esc = False
|
||||
else:
|
||||
pre_esc = False
|
||||
data.append(b)
|
||||
|
||||
if not line_mode:
|
||||
data_buffer.extend(data)
|
||||
else:
|
||||
line_buffer.extend(data)
|
||||
if line_flush:
|
||||
data_buffer.extend(line_buffer)
|
||||
line_buffer.clear()
|
||||
os.write(1, ("\b \b"*blind_write_count).encode("utf-8"))
|
||||
line_flush = False
|
||||
blind_write_count = 0
|
||||
else:
|
||||
os.write(1, data)
|
||||
blind_write_count += len(data)
|
||||
|
||||
except EOFError:
|
||||
if os.isatty(0):
|
||||
data_buffer.extend(process.CTRL_D)
|
||||
stdin_eof = True
|
||||
process.tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
|
||||
process.tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
||||
|
||||
tcattr = None
|
||||
rows, cols, hpix, vpix = (None, None, None, None)
|
||||
try:
|
||||
tcattr = termios.tcgetattr(0)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(0)
|
||||
except:
|
||||
try:
|
||||
tcattr = termios.tcgetattr(1)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(1)
|
||||
except:
|
||||
try:
|
||||
tcattr = termios.tcgetattr(2)
|
||||
rows, cols, hpix, vpix = process.tty_get_winsize(2)
|
||||
except:
|
||||
pass
|
||||
|
||||
await _spin(lambda: channel.is_ready_to_send(), "Waiting for channel...", 1, quietness > 0)
|
||||
channel.send(protocol.ExecuteCommandMesssage(cmdline=command,
|
||||
pipe_stdin=not os.isatty(0),
|
||||
pipe_stdout=not os.isatty(1),
|
||||
pipe_stderr=not os.isatty(2),
|
||||
tcflags=tcattr,
|
||||
term=os.environ.get("TERM", None),
|
||||
rows=rows,
|
||||
cols=cols,
|
||||
hpix=hpix,
|
||||
vpix=vpix))
|
||||
|
||||
loop.add_signal_handler(signal.SIGWINCH, sigwinch_handler)
|
||||
_finished = asyncio.Event()
|
||||
loop.add_signal_handler(signal.SIGINT, functools.partial(_sigint_handler, signal.SIGINT, loop))
|
||||
loop.add_signal_handler(signal.SIGTERM, functools.partial(_sigint_handler, signal.SIGTERM, loop))
|
||||
mdu = _link.MDU - 16
|
||||
sent_eof = False
|
||||
last_winch = time.time()
|
||||
sleeper = helpers.SleepRate(0.01)
|
||||
processed = False
|
||||
while not await _check_finished() and state in [InitiatorState.IS_RUNNING]:
|
||||
try:
|
||||
try:
|
||||
message = _pq.get(timeout=sleeper.next_sleep_time() if not processed else 0.0005)
|
||||
await _handle_error(message)
|
||||
processed = True
|
||||
if isinstance(message, protocol.StreamDataMessage):
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDOUT:
|
||||
if message.data and len(message.data) > 0:
|
||||
ttyRestorer.raw()
|
||||
RNS.log(f"stdout: {message.data}", RNS.LOG_DEBUG)
|
||||
os.write(1, message.data)
|
||||
sys.stdout.flush()
|
||||
if message.eof:
|
||||
os.close(1)
|
||||
if message.stream_id == protocol.StreamDataMessage.STREAM_ID_STDERR:
|
||||
if message.data and len(message.data) > 0:
|
||||
ttyRestorer.raw()
|
||||
RNS.log(f"stdout: {message.data}", RNS.LOG_DEBUG)
|
||||
os.write(2, message.data)
|
||||
sys.stderr.flush()
|
||||
if message.eof:
|
||||
os.close(2)
|
||||
elif isinstance(message, protocol.CommandExitedMessage):
|
||||
RNS.log(f"received return code {message.return_code}, exiting", RNS.LOG_DEBUG)
|
||||
return message.return_code
|
||||
elif isinstance(message, protocol.ErrorMessage):
|
||||
RNS.log(f"Remote error: {message.data}", RNS.LOG_ERROR)
|
||||
if message.fatal:
|
||||
_link.teardown()
|
||||
return 200
|
||||
|
||||
except queue.Empty:
|
||||
processed = False
|
||||
|
||||
if channel.is_ready_to_send():
|
||||
def compress_adaptive(buf: bytes):
|
||||
comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
|
||||
comp_try = 1
|
||||
comp_success = False
|
||||
|
||||
chunk_len = len(buf)
|
||||
if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN:
|
||||
chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN
|
||||
chunk_segment = None
|
||||
|
||||
chunk_segment = None
|
||||
max_data_len = channel.mdu - protocol.StreamDataMessage.OVERHEAD
|
||||
while chunk_len > 32 and comp_try < comp_tries:
|
||||
chunk_segment_length = int(chunk_len/comp_try)
|
||||
compressed_chunk = bz2.compress(buf[:chunk_segment_length])
|
||||
compressed_length = len(compressed_chunk)
|
||||
if compressed_length < max_data_len and compressed_length < chunk_segment_length:
|
||||
comp_success = True
|
||||
break
|
||||
else:
|
||||
comp_try += 1
|
||||
|
||||
if comp_success:
|
||||
diff = max_data_len - len(compressed_chunk)
|
||||
chunk = compressed_chunk
|
||||
processed_length = chunk_segment_length
|
||||
else:
|
||||
chunk = bytes(buf[:max_data_len])
|
||||
processed_length = len(chunk)
|
||||
|
||||
return comp_success, processed_length, chunk
|
||||
|
||||
comp_success, processed_length, chunk = compress_adaptive(data_buffer)
|
||||
stdin = chunk
|
||||
data_buffer = data_buffer[processed_length:]
|
||||
eof = not sent_eof and stdin_eof and len(stdin) == 0
|
||||
if len(stdin) > 0 or eof:
|
||||
channel.send(protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDIN, stdin, eof, comp_success))
|
||||
sent_eof = eof
|
||||
processed = True
|
||||
|
||||
# send window change, but rate limited
|
||||
if winch and time.time() - last_winch > _link.rtt * 25:
|
||||
last_winch = time.time()
|
||||
winch = False
|
||||
with contextlib.suppress(Exception):
|
||||
r, c, h, v = process.tty_get_winsize(0)
|
||||
channel.send(protocol.WindowSizeMessage(r, c, h, v))
|
||||
processed = True
|
||||
except RemoteExecutionError as e:
|
||||
print(e.msg)
|
||||
return 255
|
||||
except Exception as ex:
|
||||
print(f"Client exception: {ex}")
|
||||
if _link and _link.status != RNS.Link.CLOSED:
|
||||
_link.teardown()
|
||||
return 127
|
||||
|
||||
RNS.log("Main loop done", RNS.LOG_DEBUG)
|
||||
return 0
|
||||
219
RNS/Utilities/rnsh/listener.py
Normal file
219
RNS/Utilities/rnsh/listener.py
Normal file
|
|
@ -0,0 +1,219 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2016-2022 Mark Qvist / unsigned.io
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
import queue
|
||||
import shlex
|
||||
import signal
|
||||
import sys
|
||||
import termios
|
||||
import threading
|
||||
import time
|
||||
import tty
|
||||
from typing import Callable, TypeVar
|
||||
import RNS
|
||||
import rnsh.exception as exception
|
||||
import rnsh.process as process
|
||||
import rnsh.retry as retry
|
||||
import rnsh.session as session
|
||||
import re
|
||||
import contextlib
|
||||
import rnsh.args
|
||||
import pwd
|
||||
import rnsh.protocol as protocol
|
||||
import rnsh.helpers as helpers
|
||||
import rnsh.rnsh
|
||||
|
||||
|
||||
_identity = None
|
||||
_reticulum = None
|
||||
_allow_all = False
|
||||
_allowed_file = None
|
||||
_allowed_identity_hashes = []
|
||||
_allowed_file_identity_hashes = []
|
||||
_cmd: [str] | None = None
|
||||
DATA_AVAIL_MSG = "data available"
|
||||
_finished: asyncio.Event = None
|
||||
_retry_timer: retry.RetryThread | None = None
|
||||
_destination: RNS.Destination | None = None
|
||||
_loop: asyncio.AbstractEventLoop | None = None
|
||||
_no_remote_command = True
|
||||
_remote_cmd_as_args = False
|
||||
|
||||
|
||||
async def _check_finished(timeout: float = 0):
|
||||
return await process.event_wait(_finished, timeout=timeout)
|
||||
|
||||
|
||||
def _sigint_handler(sig, loop):
|
||||
global _finished
|
||||
RNS.log(f"Signal: {signal.Signals(sig).name}", RNS.LOG_DEBUG)
|
||||
if _finished is not None: _finished.set()
|
||||
else: raise KeyboardInterrupt()
|
||||
|
||||
def _reload_allowed_file():
|
||||
global _allowed_file, _allowed_file_identity_hashes
|
||||
if _allowed_file != None:
|
||||
try:
|
||||
with open(_allowed_file, "r") as file:
|
||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
|
||||
added = 0
|
||||
line = 0
|
||||
_allowed_file_identity_hashes = []
|
||||
for allow in file.read().replace("\r", "").split("\n"):
|
||||
line += 1
|
||||
if len(allow) == dest_len:
|
||||
try:
|
||||
destination_hash = bytes.fromhex(allow)
|
||||
_allowed_file_identity_hashes.append(destination_hash)
|
||||
added += 1
|
||||
except Exception:
|
||||
RNS.log(f"Discarded invalid Identity hash in {_allowed_file} at line {line}", RNS.LOG_DEBUG)
|
||||
|
||||
ms = "y" if added == 1 else "ies"
|
||||
RNS.log(f"Loaded {added} allowed identit{ms} from "+str(_allowed_file), RNS.LOG_DEBUG)
|
||||
|
||||
except Exception as e: RNS.log(f"Error while reloading allowed indetities file: {e}", RNS.LOG_ERROR)
|
||||
|
||||
def compute_target_rns_loglevel(verbosity: int, quietness: int, base_level: int = RNS.LOG_INFO) -> int:
|
||||
try:
|
||||
target = int(base_level) + int(verbosity) - int(quietness)
|
||||
if target < RNS.LOG_CRITICAL: target = RNS.LOG_CRITICAL
|
||||
if target > RNS.LOG_DEBUG: target = RNS.LOG_DEBUG
|
||||
return target
|
||||
|
||||
except Exception: return base_level
|
||||
|
||||
async def listen(configdir, rnsconfigdir, command, identitypath=None, service_name=None, verbosity=0, quietness=0, allowed=None,
|
||||
allowed_file=None, disable_auth=None, announce_period=900, no_remote_command=True, remote_cmd_as_args=False,
|
||||
loop: asyncio.AbstractEventLoop = None):
|
||||
global _identity, _allow_all, _allowed_identity_hashes, _allowed_file, _allowed_file_identity_hashes
|
||||
global _reticulum, _cmd, _destination, _no_remote_command, _remote_cmd_as_args, _finished
|
||||
|
||||
if not loop: loop = asyncio.get_running_loop()
|
||||
if service_name is None or len(service_name) == 0:
|
||||
service_name = "default"
|
||||
|
||||
RNS.log(f"Using service name {service_name}", RNS.LOG_INFO)
|
||||
|
||||
# More -v should increase verbosity (higher RNS.loglevel); -q should decrease it
|
||||
targetloglevel = compute_target_rns_loglevel(verbosity, quietness, RNS.LOG_INFO)
|
||||
_reticulum = RNS.Reticulum(configdir=rnsconfigdir, loglevel=targetloglevel)
|
||||
_identity = rnsh.rnsh.prepare_identity(identitypath, service_name)
|
||||
_destination = RNS.Destination(_identity, RNS.Destination.IN, RNS.Destination.SINGLE, rnsh.rnsh.APP_NAME)
|
||||
|
||||
RNS.log(f"rnsh listening for commands on {RNS.prettyhexrep(_destination.hash)}", RNS.LOG_NOTICE)
|
||||
|
||||
_cmd = command
|
||||
if _cmd is None or len(_cmd) == 0:
|
||||
shell = None
|
||||
try: shell = pwd.getpwuid(os.getuid()).pw_shell
|
||||
except Exception as e: RNS.log(f"Error looking up shell: {e}", RNS.LOG_ERROR)
|
||||
RNS.log(f"Using {shell} for default command.", RNS.LOG_INFO)
|
||||
|
||||
# Ensure a sane shell default. Fall back to /bin/sh if lookup fails.
|
||||
if not shell or len(shell) == 0: shell = "/bin/sh"
|
||||
_cmd = [shell]
|
||||
|
||||
else: RNS.log(f"Using command {shlex.join(_cmd)}", RNS.LOG_INFO)
|
||||
|
||||
_no_remote_command = no_remote_command
|
||||
session.ListenerSession.allow_remote_command = not no_remote_command
|
||||
_remote_cmd_as_args = remote_cmd_as_args
|
||||
if (_cmd is None or len(_cmd) == 0 or _cmd[0] is None or len(_cmd[0]) == 0) \
|
||||
and (_no_remote_command or _remote_cmd_as_args):
|
||||
raise Exception(f"Unable to look up shell for {os.getlogin}, cannot proceed with -A or -C and no <program>.")
|
||||
|
||||
session.ListenerSession.default_command = _cmd
|
||||
session.ListenerSession.remote_cmd_as_args = _remote_cmd_as_args
|
||||
|
||||
if disable_auth:
|
||||
_allow_all = True
|
||||
session.ListenerSession.allow_all = True
|
||||
else:
|
||||
if allowed_file is not None:
|
||||
_allowed_file = allowed_file
|
||||
_reload_allowed_file()
|
||||
|
||||
if allowed is not None:
|
||||
for a in allowed:
|
||||
try:
|
||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH // 8) * 2
|
||||
if len(a) != dest_len:
|
||||
raise ValueError(
|
||||
"Allowed destination length is invalid, must be {hex} hexadecimal " +
|
||||
"characters ({byte} bytes).".format(
|
||||
hex=dest_len, byte=dest_len // 2))
|
||||
try:
|
||||
destination_hash = bytes.fromhex(a)
|
||||
_allowed_identity_hashes.append(destination_hash)
|
||||
session.ListenerSession.allowed_identity_hashes.append(destination_hash)
|
||||
except Exception:
|
||||
raise ValueError("Invalid destination entered. Check your input.")
|
||||
|
||||
except Exception as e:
|
||||
RNS.log(f"Unhandled error: {e}", RNS.LOG_ERROR)
|
||||
RNS.trace_exception(e)
|
||||
exit(1)
|
||||
|
||||
if (len(_allowed_identity_hashes) < 1 and len(_allowed_file_identity_hashes) < 1) and not disable_auth:
|
||||
RNS.log("Warning: No allowed identities configured, rnsh will not accept any connections!", RNS.LOG_WARNING)
|
||||
|
||||
def link_established(lnk: RNS.Link):
|
||||
_reload_allowed_file()
|
||||
session.ListenerSession.allowed_file_identity_hashes = _allowed_file_identity_hashes
|
||||
session.ListenerSession(session.RNSOutlet.get_outlet(lnk), lnk.get_channel(), loop)
|
||||
_destination.set_link_established_callback(link_established)
|
||||
|
||||
_finished = asyncio.Event()
|
||||
signal.signal(signal.SIGINT, _sigint_handler)
|
||||
|
||||
if announce_period is not None: _destination.announce()
|
||||
|
||||
last_announce = time.time()
|
||||
sleeper = helpers.SleepRate(0.01)
|
||||
|
||||
try:
|
||||
while not await _check_finished():
|
||||
if announce_period and 0 < announce_period < time.time() - last_announce:
|
||||
last_announce = time.time()
|
||||
_destination.announce()
|
||||
if len(session.ListenerSession.sessions) > 0:
|
||||
# no sleep if there's work to do
|
||||
if not await session.ListenerSession.pump_all():
|
||||
await sleeper.sleep_async()
|
||||
else:
|
||||
await asyncio.sleep(0.25)
|
||||
finally:
|
||||
RNS.log("Shutting down", RNS.LOG_NOTICE)
|
||||
await session.ListenerSession.terminate_all("Shutting down")
|
||||
await asyncio.sleep(1)
|
||||
links_still_active = list(filter(lambda l: l.status != RNS.Link.CLOSED, _destination.links))
|
||||
for link in links_still_active:
|
||||
if link.status not in [RNS.Link.CLOSED]:
|
||||
link.teardown()
|
||||
await asyncio.sleep(0.01)
|
||||
12
RNS/Utilities/rnsh/loop.py
Normal file
12
RNS/Utilities/rnsh/loop.py
Normal file
|
|
@ -0,0 +1,12 @@
|
|||
import asyncio
|
||||
import functools
|
||||
from typing import Callable
|
||||
|
||||
def sig_handler_sys_to_loop(handler: Callable[[int, any], None]) -> Callable[[int, asyncio.AbstractEventLoop], None]:
|
||||
def wrapped(cb: Callable[[int, any], None], signal: int, loop: asyncio.AbstractEventLoop): cb(signal, None)
|
||||
return functools.partial(wrapped, handler)
|
||||
|
||||
def loop_set_signal(sig, handler: Callable[[int, asyncio.AbstractEventLoop], None], loop: asyncio.AbstractEventLoop = None):
|
||||
if loop is None: loop = asyncio.get_running_loop()
|
||||
loop.remove_signal_handler(sig)
|
||||
loop.add_signal_handler(sig, functools.partial(handler, sig, loop))
|
||||
773
RNS/Utilities/rnsh/process.py
Normal file
773
RNS/Utilities/rnsh/process.py
Normal file
|
|
@ -0,0 +1,773 @@
|
|||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Aaron Heise
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
import asyncio
|
||||
import contextlib
|
||||
import copy
|
||||
import errno
|
||||
import fcntl
|
||||
import functools
|
||||
import os
|
||||
import pty
|
||||
import select
|
||||
import signal
|
||||
import struct
|
||||
import sys
|
||||
import termios
|
||||
import threading
|
||||
import tty
|
||||
import types
|
||||
import typing
|
||||
import RNS
|
||||
|
||||
import rnsh.exception as exception
|
||||
|
||||
CTRL_C = "\x03".encode("utf-8")
|
||||
CTRL_D = "\x04".encode("utf-8")
|
||||
|
||||
def tty_add_reader_callback(fd: int, callback: callable, loop: asyncio.AbstractEventLoop = None):
|
||||
"""
|
||||
Add an async reader callback for a tty file descriptor.
|
||||
|
||||
Example usage:
|
||||
|
||||
def reader():
|
||||
data = tty_read(fd)
|
||||
# do something with data
|
||||
|
||||
tty_add_reader_callback(self._child_fd, reader, self._loop)
|
||||
|
||||
:param fd: file descriptor
|
||||
:param callback: callback function
|
||||
:param loop: asyncio event loop to which the reader should be added. If None, use the currently-running loop.
|
||||
"""
|
||||
if loop is None:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.add_reader(fd, callback)
|
||||
|
||||
|
||||
def tty_read(fd: int) -> bytes:
|
||||
"""
|
||||
Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using
|
||||
tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys.
|
||||
:param fd: tty file descriptor
|
||||
:return: bytes read
|
||||
"""
|
||||
if fd_is_closed(fd):
|
||||
raise EOFError
|
||||
|
||||
try:
|
||||
run = True
|
||||
result = bytearray()
|
||||
while not fd_is_closed(fd):
|
||||
ready, _, _ = select.select([fd], [], [], 0)
|
||||
if len(ready) == 0:
|
||||
break
|
||||
for f in ready:
|
||||
try:
|
||||
data = os.read(f, 4096)
|
||||
except OSError as e:
|
||||
if e.errno != errno.EIO and e.errno != errno.EWOULDBLOCK:
|
||||
raise
|
||||
else:
|
||||
if not data: # EOF
|
||||
if data is not None and len(data) > 0:
|
||||
result.extend(data)
|
||||
return result
|
||||
elif len(result) > 0:
|
||||
return result
|
||||
else:
|
||||
raise EOFError
|
||||
if data is not None and len(data) > 0:
|
||||
result.extend(data)
|
||||
return result
|
||||
|
||||
except EOFError: raise
|
||||
except Exception as e: RNS.log(f"TTY read error: {e}", RNS.LOG_ERROR)
|
||||
|
||||
|
||||
def tty_read_poll(fd: int) -> bytes:
|
||||
"""
|
||||
Read available bytes from a tty file descriptor. When used in a callback added to a file descriptor using
|
||||
tty_add_reader_callback(...), this function creates a solution for non-blocking reads from ttys.
|
||||
:param fd: tty file descriptor
|
||||
:return: bytes read
|
||||
"""
|
||||
if fd_is_closed(fd):
|
||||
raise EOFError
|
||||
|
||||
result = bytearray()
|
||||
try:
|
||||
flags = fcntl.fcntl(fd, fcntl.F_GETFL)
|
||||
fcntl.fcntl(fd, fcntl.F_SETFL, flags | os.O_NONBLOCK)
|
||||
while True:
|
||||
try:
|
||||
data = os.read(fd, 4096)
|
||||
if not data:
|
||||
# EOF
|
||||
if len(result) > 0:
|
||||
return result
|
||||
raise EOFError
|
||||
result.extend(data)
|
||||
# continue loop to drain
|
||||
except OSError as e:
|
||||
if e.errno in (errno.EWOULDBLOCK, errno.EAGAIN):
|
||||
break
|
||||
if e.errno == errno.EIO:
|
||||
if len(result) > 0:
|
||||
return result
|
||||
raise EOFError
|
||||
raise
|
||||
except EOFError: raise
|
||||
except Exception as e: RNS.log(f"TTY read error: {e}", RNS.LOG_ERROR)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
def fd_is_closed(fd: int) -> bool:
|
||||
"""
|
||||
Check if file descriptor is closed
|
||||
:param fd: file descriptor
|
||||
:return: True if file descriptor is closed
|
||||
"""
|
||||
try:
|
||||
fcntl.fcntl(fd, fcntl.F_GETFL) < 0
|
||||
except OSError as ose:
|
||||
return ose.errno == errno.EBADF
|
||||
|
||||
|
||||
def tty_unset_reader_callbacks(fd: int, loop: asyncio.AbstractEventLoop = None):
|
||||
"""
|
||||
Remove async reader callbacks for file descriptor.
|
||||
:param fd: file descriptor
|
||||
:param loop: asyncio event loop from which to remove callbacks
|
||||
"""
|
||||
with exception.permit(SystemExit):
|
||||
if loop is None:
|
||||
loop = asyncio.get_running_loop()
|
||||
loop.remove_reader(fd)
|
||||
|
||||
|
||||
def tty_get_winsize(fd: int) -> [int, int, int, int]:
|
||||
"""
|
||||
Ge the window size of a tty.
|
||||
:param fd: file descriptor of tty
|
||||
:return: (rows, cols, h_pixels, v_pixels)
|
||||
"""
|
||||
packed = fcntl.ioctl(fd, termios.TIOCGWINSZ, struct.pack('HHHH', 0, 0, 0, 0))
|
||||
rows, cols, h_pixels, v_pixels = struct.unpack('HHHH', packed)
|
||||
return rows, cols, h_pixels, v_pixels
|
||||
|
||||
|
||||
def tty_set_winsize(fd: int, rows: int, cols: int, h_pixels: int, v_pixels: int):
|
||||
"""
|
||||
Set the window size on a tty.
|
||||
:param fd: file descriptor of tty
|
||||
:param rows: number of visible rows
|
||||
:param cols: number of visible columns
|
||||
:param h_pixels: number of visible horizontal pixels
|
||||
:param v_pixels: number of visible vertical pixels
|
||||
"""
|
||||
if fd < 0:
|
||||
return
|
||||
packed = struct.pack('HHHH', rows, cols, h_pixels, v_pixels)
|
||||
fcntl.ioctl(fd, termios.TIOCSWINSZ, packed)
|
||||
|
||||
|
||||
def process_exists(pid) -> bool:
|
||||
"""
|
||||
Check For the existence of a unix pid.
|
||||
:param pid: process id to check
|
||||
:return: True if process exists
|
||||
"""
|
||||
try:
|
||||
os.kill(pid, 0)
|
||||
except OSError:
|
||||
return False
|
||||
else:
|
||||
return True
|
||||
|
||||
|
||||
class TTYRestorer(contextlib.AbstractContextManager):
|
||||
# Indexes of flags within the attrs array
|
||||
ATTR_IDX_IFLAG = 0
|
||||
ATTR_IDX_OFLAG = 1
|
||||
ATTR_IDX_CFLAG = 2
|
||||
ATTR_IDX_LFLAG = 4
|
||||
ATTR_IDX_CC = 5
|
||||
|
||||
def __init__(self, fd: int, suppress_logs=False):
|
||||
"""
|
||||
Saves termios attributes for a tty for later restoration.
|
||||
|
||||
The attributes are an array of values with the following meanings.
|
||||
|
||||
tcflag_t c_iflag; /* input modes */
|
||||
tcflag_t c_oflag; /* output modes */
|
||||
tcflag_t c_cflag; /* control modes */
|
||||
tcflag_t c_lflag; /* local modes */
|
||||
cc_t c_cc[NCCS]; /* special characters */
|
||||
|
||||
:param fd: file descriptor of tty
|
||||
"""
|
||||
self._fd = fd
|
||||
self._tattr = None
|
||||
self._suppress_logs = suppress_logs
|
||||
self._tattr = self.current_attr()
|
||||
if not self._tattr and not self._suppress_logs: RNS.log(f"Could not get attrs for fd {fd}", RNS.LOG_DEBUG)
|
||||
|
||||
def raw(self):
|
||||
"""
|
||||
Set raw mode on tty
|
||||
"""
|
||||
if self._fd is None:
|
||||
return
|
||||
with contextlib.suppress(termios.error):
|
||||
tty.setraw(self._fd, termios.TCSANOW)
|
||||
|
||||
def original_attr(self) -> [any]:
|
||||
return copy.deepcopy(self._tattr)
|
||||
|
||||
def current_attr(self) -> [any]:
|
||||
"""
|
||||
Get the current termios attributes for the wrapped fd.
|
||||
:return: attribute array
|
||||
"""
|
||||
if self._fd is None:
|
||||
return None
|
||||
|
||||
with contextlib.suppress(termios.error):
|
||||
return copy.deepcopy(termios.tcgetattr(self._fd))
|
||||
return None
|
||||
|
||||
def set_attr(self, attr: [any], when: int = termios.TCSADRAIN):
|
||||
"""
|
||||
Set termios attributes
|
||||
:param attr: attribute list to set
|
||||
:param when: when attributes should be applied (termios.TCSANOW, termios.TCSADRAIN, termios.TCSAFLUSH)
|
||||
"""
|
||||
if not attr or self._fd is None:
|
||||
return
|
||||
|
||||
with contextlib.suppress(termios.error):
|
||||
termios.tcsetattr(self._fd, when, attr)
|
||||
|
||||
def isatty(self):
|
||||
return os.isatty(self._fd) if self._fd is not None else None
|
||||
|
||||
def restore(self):
|
||||
"""
|
||||
Restore termios settings to state captured in constructor.
|
||||
"""
|
||||
self.set_attr(self._tattr, termios.TCSADRAIN)
|
||||
|
||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||
__traceback: types.TracebackType) -> bool:
|
||||
self.restore()
|
||||
return False #__exc_type is not None and issubclass(__exc_type, termios.error)
|
||||
|
||||
|
||||
def _task_from_event(evt: asyncio.Event, loop: asyncio.AbstractEventLoop = None):
|
||||
if not loop:
|
||||
loop = asyncio.get_running_loop()
|
||||
|
||||
#TODO: this is hacky
|
||||
async def wait():
|
||||
while not evt.is_set():
|
||||
await asyncio.sleep(0.1)
|
||||
return True
|
||||
|
||||
return loop.create_task(wait())
|
||||
|
||||
|
||||
class AggregateException(Exception):
|
||||
def __init__(self, inner_exceptions: [Exception]):
|
||||
super().__init__()
|
||||
self.inner_exceptions = inner_exceptions
|
||||
|
||||
def __str__(self):
|
||||
return "Multiple exceptions encountered: \n\n" + "\n\n".join(map(lambda e: str(e), self.inner_exceptions))
|
||||
|
||||
|
||||
async def event_wait_any(evts: [asyncio.Event], timeout: float = None) -> (any, any):
|
||||
tasks = list(map(lambda evt: (evt, _task_from_event(evt)), evts))
|
||||
try:
|
||||
finished, unfinished = await asyncio.wait(map(lambda t: t[1], tasks),
|
||||
timeout=timeout,
|
||||
return_when=asyncio.FIRST_COMPLETED)
|
||||
|
||||
if len(unfinished) > 0:
|
||||
for task in unfinished:
|
||||
task.cancel()
|
||||
await asyncio.wait(unfinished)
|
||||
|
||||
exceptions = []
|
||||
|
||||
for f in finished:
|
||||
ex = f.exception()
|
||||
if ex and not isinstance(ex, asyncio.CancelledError) and not isinstance(ex, TimeoutError):
|
||||
exceptions.append(ex)
|
||||
|
||||
if len(exceptions) > 0:
|
||||
raise AggregateException(exceptions)
|
||||
|
||||
return next(map(lambda t: next(map(lambda tt: tt[0], tasks)), finished), None)
|
||||
finally:
|
||||
unfinished = []
|
||||
for task in map(lambda t: t[1], tasks):
|
||||
if task.done():
|
||||
if not task.cancelled():
|
||||
task.exception()
|
||||
else:
|
||||
task.cancel()
|
||||
unfinished.append(task)
|
||||
if len(unfinished) > 0:
|
||||
await asyncio.wait(unfinished)
|
||||
|
||||
|
||||
async def event_wait(evt: asyncio.Event, timeout: float) -> bool:
|
||||
"""
|
||||
Wait for event to be set, or timeout to expire.
|
||||
:param evt: asyncio.Event to wait on
|
||||
:param timeout: maximum number of seconds to wait.
|
||||
:return: True if event was set, False if timeout expired
|
||||
"""
|
||||
await event_wait_any([evt], timeout=timeout)
|
||||
return evt.is_set()
|
||||
|
||||
|
||||
def _launch_child(cmd_line: list[str], env: dict[str, str], stdin_is_pipe: bool, stdout_is_pipe: bool,
|
||||
stderr_is_pipe: bool) -> tuple[int, int, int, int]:
|
||||
# Set up PTY and/or pipes
|
||||
child_fd = parent_fd = None
|
||||
if not (stdin_is_pipe and stdout_is_pipe and stderr_is_pipe):
|
||||
parent_fd, child_fd = pty.openpty()
|
||||
child_stdin, parent_stdin = (os.pipe() if stdin_is_pipe else (child_fd, parent_fd))
|
||||
parent_stdout, child_stdout = (os.pipe() if stdout_is_pipe else (parent_fd, child_fd))
|
||||
parent_stderr, child_stderr = (os.pipe() if stderr_is_pipe else (parent_fd, child_fd))
|
||||
|
||||
# Fork
|
||||
pid = os.fork()
|
||||
|
||||
if pid == 0:
|
||||
try:
|
||||
# We are in the child process, so close all open sockets and pipes except for the PTY and/or pipes
|
||||
max_fd = os.sysconf("SC_OPEN_MAX")
|
||||
for fd in range(3, max_fd):
|
||||
if fd not in (child_stdin, child_stdout, child_stderr):
|
||||
try:
|
||||
os.close(fd)
|
||||
except OSError:
|
||||
pass
|
||||
|
||||
# Set up PTY and/or pipes
|
||||
os.dup2(child_stdin, 0)
|
||||
os.dup2(child_stdout, 1)
|
||||
os.dup2(child_stderr, 2)
|
||||
# Make PTY controlling if necessary so that CTRL_C/CTRL_D behave as expected
|
||||
if child_fd is not None:
|
||||
os.setsid()
|
||||
try:
|
||||
tty_fd = 0 if not stdin_is_pipe else (1 if not stdout_is_pipe else 2)
|
||||
# Set controlling TTY for this session
|
||||
fcntl.ioctl(tty_fd, termios.TIOCSCTTY, 0)
|
||||
except Exception:
|
||||
pass
|
||||
# Ensure the child is the foreground process group for the TTY
|
||||
try:
|
||||
os.setpgid(0, 0)
|
||||
pgid = os.getpgrp()
|
||||
import struct as _struct
|
||||
fcntl.ioctl(tty_fd, termios.TIOCSPGRP, _struct.pack('i', pgid))
|
||||
except Exception:
|
||||
pass
|
||||
# Ensure canonical input with signals and local echo enabled
|
||||
try:
|
||||
tty_fd = 0 if not stdin_is_pipe else (1 if not stdout_is_pipe else 2)
|
||||
attrs = termios.tcgetattr(tty_fd)
|
||||
lflag = attrs[3]
|
||||
lflag |= termios.ICANON | termios.ISIG | termios.ECHO
|
||||
attrs[3] = lflag
|
||||
termios.tcsetattr(tty_fd, termios.TCSANOW, attrs)
|
||||
except Exception:
|
||||
pass
|
||||
|
||||
# Execute the command
|
||||
os.execvpe(cmd_line[0], cmd_line, env)
|
||||
except Exception as err:
|
||||
exc_type, exc_obj, exc_tb = sys.exc_info()
|
||||
fname = os.path.split(exc_tb.tb_frame.f_code.co_filename)[1]
|
||||
print(f"Unable to start {cmd_line[0]}: {err} ({fname}:{exc_tb.tb_lineno})")
|
||||
sys.stdout.flush()
|
||||
# don't let any other modules get in our way, do an immediate silent exit.
|
||||
os._exit(255)
|
||||
|
||||
else:
|
||||
# We are in the parent process, so close the child-side of the PTY and/or pipes
|
||||
if child_fd is not None:
|
||||
os.close(child_fd)
|
||||
if child_stdin != child_fd:
|
||||
os.close(child_stdin)
|
||||
if child_stdout != child_fd:
|
||||
os.close(child_stdout)
|
||||
if child_stderr != child_fd:
|
||||
os.close(child_stderr)
|
||||
# # Close the write end of the pipe if a pipe is used for standard input
|
||||
# if not stdin_is_pipe:
|
||||
# os.close(parent_stdin)
|
||||
# Return the child PID and the file descriptors for the PTY and/or pipes
|
||||
return pid, parent_stdin, parent_stdout, parent_stderr
|
||||
|
||||
|
||||
class CallbackSubprocess:
|
||||
# time between checks of child process
|
||||
PROCESS_POLL_TIME: float = 0.1
|
||||
# Close pipes soon after process exit to avoid scheduling on closed event loops
|
||||
PROCESS_PIPE_TIME: int = 1
|
||||
|
||||
def __init__(self, argv: [str], env: dict, loop: asyncio.AbstractEventLoop, stdout_callback: callable,
|
||||
stderr_callback: callable, terminated_callback: callable, stdin_is_pipe: bool, stdout_is_pipe: bool,
|
||||
stderr_is_pipe: bool):
|
||||
"""
|
||||
Fork a child process and generate callbacks with output from the process.
|
||||
:param argv: the command line, tokenized. The first element must be the absolute path to an executable file.
|
||||
:param env: environment variables to override
|
||||
:param loop: the asyncio event loop to use
|
||||
:param stdout_callback: callback for data, e.g. def callback(data:bytes) -> None
|
||||
:param terminated_callback: callback for termination/return code, e.g. def callback(return_code:int) -> None
|
||||
"""
|
||||
assert loop is not None, "loop should not be None"
|
||||
assert stdout_callback is not None, "stdout_callback should not be None"
|
||||
assert terminated_callback is not None, "terminated_callback should not be None"
|
||||
|
||||
self._command: [str] = argv
|
||||
self._env = env or {}
|
||||
self._loop = loop
|
||||
self._stdout_cb = stdout_callback
|
||||
self._stderr_cb = stderr_callback
|
||||
self._terminated_cb = terminated_callback
|
||||
self._pid: int = None
|
||||
self._child_stdin: int = None
|
||||
self._child_stdout: int = None
|
||||
self._child_stderr: int = None
|
||||
self._return_code: int = None
|
||||
self._stdout_eof: bool = False
|
||||
self._stderr_eof: bool = False
|
||||
self._stdin_is_pipe = stdin_is_pipe
|
||||
self._stdout_is_pipe = stdout_is_pipe
|
||||
self._stderr_is_pipe = stderr_is_pipe
|
||||
self._at_line_start: bool = True
|
||||
self._tty_line_buffer: bytearray = bytearray()
|
||||
|
||||
def _ensure_pipes_closed(self):
|
||||
stdin = self._child_stdin
|
||||
stdout = self._child_stdout
|
||||
stderr = self._child_stderr
|
||||
fds = set(filter(lambda x: x is not None, list({stdin, stdout, stderr})))
|
||||
RNS.log(f"Queuing close of pipes for ended process (fds: {fds})", RNS.LOG_DEBUG)
|
||||
|
||||
def ensure_pipes_closed_inner():
|
||||
RNS.log(f"Ensuring pipes are closed (fds: {fds})", RNS.LOG_DEBUG)
|
||||
for fd in fds:
|
||||
RNS.log(f"Closing fd {fd}", RNS.LOG_DEBUG)
|
||||
with contextlib.suppress(OSError): tty_unset_reader_callbacks(fd)
|
||||
with contextlib.suppress(OSError): os.close(fd)
|
||||
|
||||
self._child_stdin = None
|
||||
self._child_stdout = None
|
||||
self._child_stderr = None
|
||||
|
||||
# Avoid scheduling on a closed loop
|
||||
if self._loop.is_closed(): ensure_pipes_closed_inner()
|
||||
else: self._loop.call_later(CallbackSubprocess.PROCESS_PIPE_TIME, ensure_pipes_closed_inner)
|
||||
|
||||
def terminate(self, kill_delay: float = 1.0):
|
||||
"""
|
||||
Terminate child process if running
|
||||
:param kill_delay: if after kill_delay seconds the child process has not exited, escalate to SIGHUP and SIGKILL
|
||||
"""
|
||||
|
||||
RNS.log("terminate()", RNS.LOG_EXTREME)
|
||||
if not self.running: return
|
||||
|
||||
with exception.permit(SystemExit): os.kill(self._pid, signal.SIGTERM)
|
||||
|
||||
def kill():
|
||||
if process_exists(self._pid):
|
||||
RNS.log("kill()", RNS.LOG_EXTREME)
|
||||
with exception.permit(SystemExit):
|
||||
os.kill(self._pid, signal.SIGHUP)
|
||||
os.kill(self._pid, signal.SIGKILL)
|
||||
|
||||
self._loop.call_later(kill_delay, kill)
|
||||
|
||||
def wait():
|
||||
RNS.log("wait()", RNS.LOG_EXTREME)
|
||||
with contextlib.suppress(OSError): os.waitpid(self._pid, 0)
|
||||
self._ensure_pipes_closed()
|
||||
RNS.log("wait() finish", RNS.LOG_EXTREME)
|
||||
|
||||
threading.Thread(target=wait, daemon=True).start()
|
||||
|
||||
def close_stdin(self):
|
||||
with contextlib.suppress(Exception):
|
||||
os.close(self._child_stdin)
|
||||
# Encourage prompt shutdown if child lingers after stdin close
|
||||
def _ensure_terminate():
|
||||
if self.running:
|
||||
self.terminate(kill_delay=0.2)
|
||||
if not self._loop.is_closed():
|
||||
self._loop.call_later(0.05, _ensure_terminate)
|
||||
|
||||
@property
|
||||
def started(self) -> bool:
|
||||
"""
|
||||
:return: True if child process has been started
|
||||
"""
|
||||
return self._pid is not None
|
||||
|
||||
@property
|
||||
def running(self) -> bool:
|
||||
"""
|
||||
:return: True if child process is still running
|
||||
"""
|
||||
return self._pid is not None and process_exists(self._pid)
|
||||
|
||||
def write(self, data: bytes):
|
||||
"""
|
||||
Write bytes to the stdin of the child process.
|
||||
:param data: bytes to write
|
||||
"""
|
||||
|
||||
os.write(self._child_stdin, data)
|
||||
|
||||
# TODO: Check what this is actually supposed to solve.
|
||||
#
|
||||
# For pipe-in + TTY-out, echo should be visible immediately
|
||||
if self._stdin_is_pipe and not self._stdout_is_pipe and self._stdout_cb is not None and data not in (CTRL_C, CTRL_D):
|
||||
try: self._stdout_cb(data)
|
||||
except Exception: pass
|
||||
|
||||
def set_winsize(self, r: int, c: int, h: int, v: int):
|
||||
"""
|
||||
Set the window size on the tty of the child process.
|
||||
:param r: rows visible
|
||||
:param c: columns visible
|
||||
:param h: horizontal pixels visible
|
||||
:param v: vertical pixels visible
|
||||
:return:
|
||||
"""
|
||||
RNS.log(f"set_winsize({r},{c},{h},{v}", RNS.LOG_DEBUG)
|
||||
tty_set_winsize(self._child_stdout, r, c, h, v)
|
||||
|
||||
def copy_winsize(self, fromfd: int):
|
||||
"""
|
||||
Copy window size from one tty to another.
|
||||
:param fromfd: source tty file descriptor
|
||||
"""
|
||||
r, c, h, v = tty_get_winsize(fromfd)
|
||||
self.set_winsize(r, c, h, v)
|
||||
|
||||
def tcsetattr(self, when: int, attr: list[any]): # actual type is list[int | list[int | bytes]]
|
||||
"""
|
||||
Set tty attributes.
|
||||
:param when: when to apply change: termios.TCSANOW or termios.TCSADRAIN or termios.TCSAFLUSH
|
||||
:param attr: attributes to set
|
||||
"""
|
||||
termios.tcsetattr(self._child_stdin, when, attr)
|
||||
|
||||
def tcgetattr(self) -> list[any]: # actual type is list[int | list[int | bytes]]
|
||||
"""
|
||||
Get tty attributes.
|
||||
:return: tty attributes value
|
||||
"""
|
||||
return termios.tcgetattr(self._child_stdout)
|
||||
|
||||
def ttysetraw(self):
|
||||
tty.setraw(self._child_stdout, termios.TCSADRAIN)
|
||||
|
||||
def start(self):
|
||||
"""
|
||||
Start the child process.
|
||||
"""
|
||||
RNS.log("start()", RNS.LOG_EXTREME)
|
||||
|
||||
# # Using the parent environment seems to do some weird stuff, at least on macOS
|
||||
# parentenv = os.environ.copy()
|
||||
# env = {"HOME": parentenv["HOME"],
|
||||
# "PATH": parentenv["PATH"],
|
||||
# "TERM": self._term if self._term is not None else parentenv.get("TERM", "xterm"),
|
||||
# "LANG": parentenv.get("LANG"),
|
||||
# "SHELL": self._command[0]}
|
||||
|
||||
env = os.environ.copy()
|
||||
for key in self._env:
|
||||
env[key] = self._env[key]
|
||||
|
||||
program = self._command[0]
|
||||
assert isinstance(program, str)
|
||||
|
||||
# match = re.search("^/bin/(.*sh)$", program)
|
||||
# if match:
|
||||
# self._command[0] = "-" + match.group(1)
|
||||
# env["SHELL"] = program
|
||||
# self._log.debug(f"set login shell {self._command}")
|
||||
|
||||
self._pid, \
|
||||
self._child_stdin, \
|
||||
self._child_stdout, \
|
||||
self._child_stderr = _launch_child(self._command, env, self._stdin_is_pipe, self._stdout_is_pipe,
|
||||
self._stderr_is_pipe)
|
||||
RNS.log(f"Started pid {self.pid}, fds: {self._child_stdin}, {self._child_stdout}, {self._child_stderr}", RNS.LOG_DEBUG)
|
||||
|
||||
def poll():
|
||||
try:
|
||||
pid, self._return_code = os.waitpid(self._pid, os.WNOHANG)
|
||||
if self._return_code is not None:
|
||||
self._return_code = self._return_code & 0xff
|
||||
if self._return_code is not None and not process_exists(self._pid):
|
||||
RNS.log(f"polled return code {self._return_code}", RNS.LOG_DEBUG)
|
||||
self._terminated_cb(self._return_code)
|
||||
if self.running:
|
||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||
else:
|
||||
self._ensure_pipes_closed()
|
||||
except Exception as e:
|
||||
if not hasattr(e, "errno") or e.errno != errno.ECHILD:
|
||||
RNS.log(f"Error in process poll: {e}", RNS.LOG_DEBUG)
|
||||
|
||||
self._loop.call_later(CallbackSubprocess.PROCESS_POLL_TIME, poll)
|
||||
|
||||
def stdout():
|
||||
try:
|
||||
with exception.permit(SystemExit):
|
||||
data = tty_read_poll(self._child_stdout)
|
||||
if data is not None and len(data) > 0:
|
||||
self._stdout_cb(data)
|
||||
# Opportunistically drain shortly after to coalesce immediate follow-up output
|
||||
if not self._loop.is_closed():
|
||||
self._loop.call_later(0.01, stdout)
|
||||
except EOFError:
|
||||
self._stdout_eof = True
|
||||
tty_unset_reader_callbacks(self._child_stdout)
|
||||
self._stdout_cb(bytearray())
|
||||
|
||||
def stderr():
|
||||
try:
|
||||
with exception.permit(SystemExit):
|
||||
data = tty_read_poll(self._child_stderr)
|
||||
if data is not None and len(data) > 0:
|
||||
self._stderr_cb(data)
|
||||
if not self._loop.is_closed():
|
||||
self._loop.call_later(0.01, stderr)
|
||||
except EOFError:
|
||||
self._stderr_eof = True
|
||||
tty_unset_reader_callbacks(self._child_stderr)
|
||||
self._stderr_cb(bytearray())
|
||||
|
||||
tty_add_reader_callback(self._child_stdout, stdout, self._loop)
|
||||
if self._child_stderr != self._child_stdout:
|
||||
tty_add_reader_callback(self._child_stderr, stderr, self._loop)
|
||||
|
||||
@property
|
||||
def stdout_eof(self):
|
||||
return self._stdout_eof or not self.running
|
||||
|
||||
@property
|
||||
def stderr_eof(self):
|
||||
return self._stderr_eof or not self.running
|
||||
|
||||
|
||||
@property
|
||||
def return_code(self) -> int:
|
||||
return self._return_code
|
||||
|
||||
@property
|
||||
def pid(self) -> int:
|
||||
return self._pid
|
||||
|
||||
|
||||
async def main():
|
||||
"""
|
||||
A test driver for the CallbackProcess class.
|
||||
python ./process.py /bin/zsh --login
|
||||
"""
|
||||
|
||||
if len(sys.argv) <= 1:
|
||||
print(f"Usage: {sys.argv} <absolute_path_to_child_executable> [child_arg ...]")
|
||||
exit(1)
|
||||
|
||||
loop = asyncio.get_event_loop()
|
||||
# asyncio.set_event_loop(loop)
|
||||
retcode = loop.create_future()
|
||||
|
||||
def stdout(data: bytes): os.write(sys.stdout.fileno(), data)
|
||||
|
||||
def terminated(rc: int): retcode.set_result(rc)
|
||||
|
||||
process = CallbackSubprocess(argv=sys.argv[1:],
|
||||
env={"TERM": os.environ.get("TERM", "xterm")},
|
||||
loop=loop,
|
||||
stdout_callback=stdout,
|
||||
terminated_callback=terminated)
|
||||
|
||||
def sigint_handler(sig, frame):
|
||||
if process is None or process.started and not process.running:
|
||||
raise KeyboardInterrupt
|
||||
elif process.running:
|
||||
process.write("\x03".encode("utf-8"))
|
||||
|
||||
def sigwinch_handler(sig, frame):
|
||||
process.copy_winsize(sys.stdin.fileno())
|
||||
|
||||
signal.signal(signal.SIGINT, sigint_handler)
|
||||
signal.signal(signal.SIGWINCH, sigwinch_handler)
|
||||
|
||||
def stdin():
|
||||
try:
|
||||
data = tty_read(sys.stdin.fileno())
|
||||
if data is not None:
|
||||
process.write(data)
|
||||
|
||||
except EOFError:
|
||||
tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
process.write(CTRL_D)
|
||||
|
||||
tty_add_reader_callback(sys.stdin.fileno(), stdin)
|
||||
process.start()
|
||||
# call_soon called it too soon, not sure why.
|
||||
loop.call_later(0.001, functools.partial(process.copy_winsize, sys.stdin.fileno()))
|
||||
|
||||
val = await retcode
|
||||
RNS.log(f"Got return code {val}", RNS.LOG_DEBUG)
|
||||
return val
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
tr = TTYRestorer(sys.stdin.fileno())
|
||||
try:
|
||||
tr.raw()
|
||||
asyncio.run(main())
|
||||
finally:
|
||||
tty_unset_reader_callbacks(sys.stdin.fileno())
|
||||
tr.restore()
|
||||
115
RNS/Utilities/rnsh/protocol.py
Normal file
115
RNS/Utilities/rnsh/protocol.py
Normal file
|
|
@ -0,0 +1,115 @@
|
|||
from __future__ import annotations
|
||||
|
||||
import RNS
|
||||
from RNS.vendor import umsgpack
|
||||
from RNS.Buffer import StreamDataMessage as RNSStreamDataMessage
|
||||
import rnsh.retry
|
||||
import abc
|
||||
import contextlib
|
||||
import struct
|
||||
from abc import ABC, abstractmethod
|
||||
|
||||
MSG_MAGIC = 0xac
|
||||
PROTOCOL_VERSION = 1
|
||||
|
||||
def _make_MSGTYPE(val: int):
|
||||
return ((MSG_MAGIC << 8) & 0xff00) | (val & 0x00ff)
|
||||
|
||||
|
||||
class NoopMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(0)
|
||||
def pack(self) -> bytes: return bytes()
|
||||
def unpack(self, raw): pass
|
||||
|
||||
|
||||
class WindowSizeMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(2)
|
||||
|
||||
def __init__(self, rows: int = None, cols: int = None, hpix: int = None, vpix: int = None):
|
||||
super().__init__()
|
||||
self.rows = rows
|
||||
self.cols = cols
|
||||
self.hpix = hpix
|
||||
self.vpix = vpix
|
||||
|
||||
def pack(self) -> bytes: return umsgpack.packb((self.rows, self.cols, self.hpix, self.vpix))
|
||||
def unpack(self, raw): self.rows, self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class ExecuteCommandMesssage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(3)
|
||||
|
||||
def __init__(self, cmdline: [str] = None, pipe_stdin: bool = False, pipe_stdout: bool = False,
|
||||
pipe_stderr: bool = False, tcflags: [any] = None, term: str | None = None, rows: int = None,
|
||||
cols: int = None, hpix: int = None, vpix: int = None):
|
||||
|
||||
super().__init__()
|
||||
self.cmdline = cmdline
|
||||
self.pipe_stdin = pipe_stdin
|
||||
self.pipe_stdout = pipe_stdout
|
||||
self.pipe_stderr = pipe_stderr
|
||||
self.tcflags = tcflags
|
||||
self.term = term
|
||||
self.rows = rows
|
||||
self.cols = cols
|
||||
self.hpix = hpix
|
||||
self.vpix = vpix
|
||||
|
||||
def pack(self) -> bytes:
|
||||
return umsgpack.packb((self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr,
|
||||
self.tcflags, self.term, self.rows, self.cols, self.hpix, self.vpix))
|
||||
|
||||
def unpack(self, raw):
|
||||
self.cmdline, self.pipe_stdin, self.pipe_stdout, self.pipe_stderr, self.tcflags, self.term, self.rows, \
|
||||
self.cols, self.hpix, self.vpix = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
# Create a version of RNS.Buffer.StreamDataMessage that we control
|
||||
class StreamDataMessage(RNSStreamDataMessage):
|
||||
MSGTYPE = _make_MSGTYPE(4)
|
||||
STREAM_ID_STDIN = 0
|
||||
STREAM_ID_STDOUT = 1
|
||||
STREAM_ID_STDERR = 2
|
||||
|
||||
|
||||
class VersionInfoMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(5)
|
||||
|
||||
def __init__(self, sw_version: str = None):
|
||||
super().__init__()
|
||||
self.sw_version = sw_version or rnsh.__version__
|
||||
self.protocol_version = PROTOCOL_VERSION
|
||||
|
||||
def pack(self) -> bytes: return umsgpack.packb((self.sw_version, self.protocol_version))
|
||||
def unpack(self, raw): self.sw_version, self.protocol_version = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class ErrorMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(6)
|
||||
|
||||
def __init__(self, msg: str = None, fatal: bool = False, data: dict = None):
|
||||
super().__init__()
|
||||
self.msg = msg
|
||||
self.fatal = fatal
|
||||
self.data = data
|
||||
|
||||
def pack(self) -> bytes: return umsgpack.packb((self.msg, self.fatal, self.data))
|
||||
def unpack(self, raw: bytes): self.msg, self.fatal, self.data = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
class CommandExitedMessage(RNS.MessageBase):
|
||||
MSGTYPE = _make_MSGTYPE(7)
|
||||
|
||||
def __init__(self, return_code: int = None):
|
||||
super().__init__()
|
||||
self.return_code = return_code
|
||||
|
||||
def pack(self) -> bytes: return umsgpack.packb(self.return_code)
|
||||
def unpack(self, raw: bytes): self.return_code = umsgpack.unpackb(raw)
|
||||
|
||||
|
||||
message_types = [NoopMessage, VersionInfoMessage, WindowSizeMessage, ExecuteCommandMesssage, StreamDataMessage,
|
||||
CommandExitedMessage, ErrorMessage]
|
||||
|
||||
def register_message_types(channel: RNS.Channel.Channel):
|
||||
for message_type in message_types: channel.register_message_type(message_type)
|
||||
189
RNS/Utilities/rnsh/retry.py
Normal file
189
RNS/Utilities/rnsh/retry.py
Normal file
|
|
@ -0,0 +1,189 @@
|
|||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2023 Aaron Heise
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
import asyncio
|
||||
import threading
|
||||
import time
|
||||
import rnsh.exception as exception
|
||||
from typing import Callable
|
||||
from contextlib import AbstractContextManager
|
||||
import types
|
||||
import typing
|
||||
|
||||
|
||||
class RetryStatus:
|
||||
def __init__(self, tag: any, try_limit: int, wait_delay: float, retry_callback: Callable[[any, int], any],
|
||||
timeout_callback: Callable[[any, int], None], tries: int = 1):
|
||||
|
||||
self.tag = tag
|
||||
self.try_limit = try_limit
|
||||
self.tries = tries
|
||||
self.wait_delay = wait_delay
|
||||
self.retry_callback = retry_callback
|
||||
self.timeout_callback = timeout_callback
|
||||
self.try_time = time.time()
|
||||
self.completed = False
|
||||
|
||||
@property
|
||||
def ready(self):
|
||||
ready = time.time() > self.try_time + self.wait_delay
|
||||
RNS.log(f"ready check {self.tag} try_time {self.try_time} wait_delay {self.wait_delay} " +
|
||||
f"next_try {self.try_time + self.wait_delay} now {time.time()} " +
|
||||
f"exceeded {time.time() - self.try_time - self.wait_delay} ready {ready}", RNS.LOG_DEBUG)
|
||||
return ready
|
||||
|
||||
@property
|
||||
def timed_out(self):
|
||||
return self.ready and self.tries >= self.try_limit
|
||||
|
||||
def timeout(self):
|
||||
self.completed = True
|
||||
self.timeout_callback(self.tag, self.tries)
|
||||
|
||||
def retry(self) -> any:
|
||||
self.tries = self.tries + 1
|
||||
self.try_time = time.time()
|
||||
return self.retry_callback(self.tag, self.tries)
|
||||
|
||||
|
||||
class RetryThread(AbstractContextManager):
|
||||
def __init__(self, loop_period: float = 0.25, name: str = "retry thread"):
|
||||
self._loop_period = loop_period
|
||||
self._statuses: list[RetryStatus] = []
|
||||
self._tag_counter = 0
|
||||
self._lock = threading.RLock()
|
||||
self._run = True
|
||||
self._finished: asyncio.Future = None
|
||||
self._thread = threading.Thread(name=name, target=self._thread_run, daemon=True)
|
||||
self._thread.start()
|
||||
|
||||
def is_alive(self):
|
||||
return self._thread.is_alive()
|
||||
|
||||
def close(self, loop: asyncio.AbstractEventLoop = None) -> asyncio.Future:
|
||||
RNS.log("Stopping timer thread", RNS.LOG_DEBUG)
|
||||
if loop is None:
|
||||
self._run = False
|
||||
self._thread.join()
|
||||
return None
|
||||
else:
|
||||
self._finished = loop.create_future()
|
||||
return self._finished
|
||||
|
||||
def wait(self, timeout: float = None):
|
||||
if timeout:
|
||||
timeout = timeout + time.time()
|
||||
|
||||
while timeout is None or time.time() < timeout:
|
||||
with self._lock:
|
||||
task_count = len(self._statuses)
|
||||
if task_count == 0:
|
||||
return
|
||||
time.sleep(0.1)
|
||||
|
||||
|
||||
def _thread_run(self):
|
||||
while self._run and self._finished is None:
|
||||
time.sleep(self._loop_period)
|
||||
ready: list[RetryStatus] = []
|
||||
prune: list[RetryStatus] = []
|
||||
with self._lock: ready.extend(list(filter(lambda s: s.ready, self._statuses)))
|
||||
|
||||
for retry in ready:
|
||||
try:
|
||||
if not retry.completed:
|
||||
if retry.timed_out:
|
||||
RNS.log(f"Timed out {retry.tag} after {retry.try_limit} tries", RNS.LOG_DEBUG)
|
||||
retry.timeout()
|
||||
prune.append(retry)
|
||||
elif retry.ready:
|
||||
RNS.log(f"Retrying {retry.tag}, try {retry.tries + 1}/{retry.try_limit}", RNS.LOG_DEBUG)
|
||||
should_continue = retry.retry()
|
||||
if not should_continue: self.complete(retry.tag)
|
||||
|
||||
except Exception as e:
|
||||
RNS.log(f"Error processing retry id {retry.tag}: {e}", RNS.LOG_ERROR)
|
||||
prune.append(retry)
|
||||
|
||||
with self._lock:
|
||||
for retry in prune:
|
||||
RNS.log(f"pruned retry {retry.tag}, retry count {retry.tries}/{retry.try_limit}", RNS.LOG_DEBUG)
|
||||
with exception.permit(SystemExit): self._statuses.remove(retry)
|
||||
|
||||
if self._finished is not None: self._finished.set_result(None)
|
||||
|
||||
def _get_next_tag(self):
|
||||
self._tag_counter += 1
|
||||
return self._tag_counter
|
||||
|
||||
def has_tag(self, tag: any) -> bool:
|
||||
with self._lock: return next(filter(lambda s: s.tag == tag, self._statuses), None) is not None
|
||||
|
||||
def begin(self, try_limit: int, wait_delay: float, try_callback: Callable[[any, int], any],
|
||||
timeout_callback: Callable[[any, int], None]) -> any:
|
||||
|
||||
RNS.log(f"Running first try", RNS.LOG_DEBUG)
|
||||
tag = try_callback(None, 1)
|
||||
RNS.log(f"First try got id {tag}", RNS.LOG_DEBUG)
|
||||
|
||||
if not tag:
|
||||
RNS.log(f"Callback returned None/False/0, considering complete.", RNS.LOG_DEBUG)
|
||||
return None
|
||||
|
||||
with self._lock:
|
||||
if tag is None: tag = self._get_next_tag()
|
||||
self.complete(tag)
|
||||
|
||||
self._statuses.append(RetryStatus(tag=tag,
|
||||
tries=1,
|
||||
try_limit=try_limit,
|
||||
wait_delay=wait_delay,
|
||||
retry_callback=try_callback,
|
||||
timeout_callback=timeout_callback))
|
||||
|
||||
RNS.log(f"Added retry timer for {tag}", RNS.LOG_DEBUG)
|
||||
return tag
|
||||
|
||||
def complete(self, tag: any):
|
||||
assert tag is not None
|
||||
with self._lock:
|
||||
status = next(filter(lambda l: l.tag == tag, self._statuses), None)
|
||||
if status is not None:
|
||||
status.completed = True
|
||||
self._statuses.remove(status)
|
||||
RNS.log(f"completed {tag}", RNS.LOG_DEBUG)
|
||||
return
|
||||
|
||||
RNS.log(f"status not found to complete {tag}", RNS.LOG_DEBUG)
|
||||
|
||||
def complete_all(self):
|
||||
with self._lock:
|
||||
for status in self._statuses:
|
||||
status.completed = True
|
||||
RNS.log(f"completed {status.tag}", RNS.LOG_DEBUG)
|
||||
|
||||
self._statuses.clear()
|
||||
|
||||
def __exit__(self, __exc_type: typing.Type[BaseException], __exc_value: BaseException,
|
||||
__traceback: types.TracebackType) -> bool:
|
||||
self.close()
|
||||
return False
|
||||
161
RNS/Utilities/rnsh/rnsh.py
Normal file
161
RNS/Utilities/rnsh/rnsh.py
Normal file
|
|
@ -0,0 +1,161 @@
|
|||
#!/usr/bin/env python3
|
||||
|
||||
# MIT License
|
||||
#
|
||||
# Copyright (c) 2016-2022 Mark Qvist / unsigned.io
|
||||
#
|
||||
# Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
# of this software and associated documentation files (the "Software"), to deal
|
||||
# in the Software without restriction, including without limitation the rights
|
||||
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
# copies of the Software, and to permit persons to whom the Software is
|
||||
# furnished to do so, subject to the following conditions:
|
||||
#
|
||||
# The above copyright notice and this permission notice shall be included in all
|
||||
# copies or substantial portions of the Software.
|
||||
#
|
||||
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
# SOFTWARE.
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import asyncio
|
||||
import base64
|
||||
|
||||
import re
|
||||
import os
|
||||
import sys
|
||||
|
||||
import RNS
|
||||
import rnsh.process as process
|
||||
import rnsh.session as session
|
||||
import rnsh.args
|
||||
import rnsh.loop
|
||||
import rnsh.listener as listener
|
||||
import rnsh.initiator as initiator
|
||||
|
||||
APP_NAME = "rnsh"
|
||||
loop: asyncio.AbstractEventLoop | None = None
|
||||
|
||||
def _sanitize_service_name(service_name:str) -> str: return re.sub(r'\W+', '', service_name)
|
||||
|
||||
def prepare_identity(identity_path, service_name: str = None) -> tuple[RNS.Identity]:
|
||||
service_name = _sanitize_service_name(service_name or "")
|
||||
if identity_path is None:
|
||||
identity_path = RNS.Reticulum.identitypath + "/" + APP_NAME + \
|
||||
(f".{service_name}" if service_name and len(service_name) > 0 else "")
|
||||
|
||||
identity = None
|
||||
if os.path.isfile(identity_path):
|
||||
identity = RNS.Identity.from_file(identity_path)
|
||||
|
||||
if identity is None:
|
||||
RNS.log("No valid saved identity found, creating new...", RNS.LOG_INFO)
|
||||
identity = RNS.Identity()
|
||||
identity.to_file(identity_path)
|
||||
|
||||
return identity
|
||||
|
||||
|
||||
def print_identity(configdir, identitypath, service_name, include_destination: bool):
|
||||
reticulum = RNS.Reticulum(configdir=configdir, loglevel=RNS.LOG_INFO)
|
||||
if service_name and len(service_name) > 0:
|
||||
print(f"Using service name \"{service_name}\"")
|
||||
identity = prepare_identity(identitypath, service_name)
|
||||
destination = RNS.Destination(identity, RNS.Destination.IN, RNS.Destination.SINGLE, APP_NAME)
|
||||
print("Identity : " + str(identity))
|
||||
if include_destination:
|
||||
print("Listening on : " + RNS.prettyhexrep(destination.hash))
|
||||
|
||||
exit(0)
|
||||
|
||||
verbose_set = False
|
||||
|
||||
def ensure_config_directory():
|
||||
if os.path.isdir(os.path.expanduser("~/.config/rnsh")): return os.path.expanduser("~/.config/rnsh")
|
||||
elif os.path.isdir(os.path.expanduser("~/.rnsh")): return os.path.expanduser("~/.rnsh")
|
||||
else:
|
||||
try:
|
||||
os.makedirs(os.path.expanduser("~/.rnsh"))
|
||||
return os.path.expanduser("~/.rnsh")
|
||||
|
||||
except Exception as e:
|
||||
RNS.log(f"Could not get or create rnsh configuration directory, aborting", RNS.LOG_CRITICAL)
|
||||
os._exit(1)
|
||||
|
||||
|
||||
async def _rnsh_cli_main():
|
||||
global verbose_set
|
||||
args = rnsh.args.Args(sys.argv)
|
||||
verbose_set = args.verbose > 0
|
||||
|
||||
configdir = ensure_config_directory()
|
||||
|
||||
if args.print_identity:
|
||||
print_identity(args.config, args.identity, args.service_name, args.listen)
|
||||
return 0
|
||||
|
||||
if args.listen:
|
||||
allowed_file = None
|
||||
dest_len = (RNS.Reticulum.TRUNCATED_HASHLENGTH//8)*2
|
||||
if os.path.isfile(os.path.expanduser("~/.config/rnsh/allowed_identities")):
|
||||
allowed_file = os.path.expanduser("~/.config/rnsh/allowed_identities")
|
||||
elif os.path.isfile(os.path.expanduser("~/.rnsh/allowed_identities")):
|
||||
allowed_file = os.path.expanduser("~/.rnsh/allowed_identities")
|
||||
|
||||
await listener.listen(configdir=configdir,
|
||||
rnsconfigdir=args.config,
|
||||
command=args.command_line,
|
||||
identitypath=args.identity,
|
||||
service_name=args.service_name,
|
||||
verbosity=args.verbose,
|
||||
quietness=args.quiet,
|
||||
allowed=args.allowed,
|
||||
allowed_file=allowed_file,
|
||||
disable_auth=args.no_auth,
|
||||
announce_period=args.announce,
|
||||
no_remote_command=args.no_remote_cmd,
|
||||
remote_cmd_as_args=args.remote_cmd_as_args)
|
||||
return 0
|
||||
|
||||
if args.destination is not None:
|
||||
return_code = await initiator.initiate(configdir=configdir,
|
||||
rnsconfigdir=args.config,
|
||||
identitypath=args.identity,
|
||||
verbosity=args.verbose,
|
||||
quietness=args.quiet,
|
||||
noid=args.no_id,
|
||||
destination=args.destination,
|
||||
timeout=args.timeout,
|
||||
command=args.command_line
|
||||
)
|
||||
return return_code if args.mirror else 0
|
||||
else:
|
||||
print("")
|
||||
print(rnsh.args.usage)
|
||||
print("")
|
||||
return 1
|
||||
|
||||
|
||||
def main():
|
||||
global verbose_set
|
||||
return_code = 1
|
||||
exc = None
|
||||
try: return_code = asyncio.run(_rnsh_cli_main())
|
||||
except SystemExit: pass
|
||||
except KeyboardInterrupt: pass
|
||||
except Exception as ex:
|
||||
print(f"Unhandled exception: {ex}")
|
||||
exc = ex
|
||||
|
||||
process.tty_unset_reader_callbacks(0)
|
||||
if verbose_set and exc: raise exc
|
||||
sys.exit(return_code if return_code is not None else 255)
|
||||
|
||||
|
||||
if __name__ == "__main__": main()
|
||||
407
RNS/Utilities/rnsh/session.py
Normal file
407
RNS/Utilities/rnsh/session.py
Normal file
|
|
@ -0,0 +1,407 @@
|
|||
from __future__ import annotations
|
||||
import contextlib
|
||||
import functools
|
||||
import rnsh.exception as exception
|
||||
import asyncio
|
||||
import rnsh.process as process
|
||||
import rnsh.helpers as helpers
|
||||
import rnsh.protocol as protocol
|
||||
import enum
|
||||
from typing import TypeVar, Generic, Callable, List
|
||||
from abc import abstractmethod, ABC
|
||||
from multiprocessing import Manager
|
||||
import os
|
||||
import bz2
|
||||
import RNS
|
||||
|
||||
_TLink = TypeVar("_TLink")
|
||||
_TIdentity = TypeVar("_TIdentity")
|
||||
|
||||
class SEType(enum.IntEnum):
|
||||
SE_LINK_CLOSED = 0
|
||||
|
||||
class SessionException(Exception):
|
||||
def __init__(self, setype: SEType, msg: str, *args):
|
||||
super().__init__(msg, args)
|
||||
self.type = setype
|
||||
|
||||
class LSState(enum.IntEnum):
|
||||
LSSTATE_WAIT_IDENT = 1
|
||||
LSSTATE_WAIT_VERS = 2
|
||||
LSSTATE_WAIT_CMD = 3
|
||||
LSSTATE_RUNNING = 4
|
||||
LSSTATE_ERROR = 5
|
||||
LSSTATE_TEARDOWN = 6
|
||||
|
||||
|
||||
class LSOutletBase(ABC):
|
||||
@abstractmethod
|
||||
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]): raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]): raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def unset_link_closed_callback(self): raise NotImplemented()
|
||||
|
||||
@property
|
||||
@abstractmethod
|
||||
def rtt(self): raise NotImplemented()
|
||||
|
||||
@abstractmethod
|
||||
def teardown(self): raise NotImplemented()
|
||||
|
||||
|
||||
class ListenerSession:
|
||||
sessions: List[ListenerSession] = []
|
||||
allowed_identity_hashes: [any] = []
|
||||
allowed_file_identity_hashes: [any] = []
|
||||
allow_all: bool = False
|
||||
allow_remote_command: bool = False
|
||||
default_command: [str] = []
|
||||
remote_cmd_as_args = False
|
||||
|
||||
def __init__(self, outlet: LSOutletBase, channel: RNS.Channel.Channel, loop: asyncio.AbstractEventLoop):
|
||||
RNS.log(f"Session started for {outlet}", RNS.LOG_INFO)
|
||||
self.outlet = outlet
|
||||
self.channel = channel
|
||||
self.outlet.set_initiator_identified_callback(self._initiator_identified)
|
||||
self.outlet.set_link_closed_callback(self._link_closed)
|
||||
self.loop = loop
|
||||
self.state: LSState = None
|
||||
self.remote_identity = None
|
||||
self.term: str | None = None
|
||||
self.stdin_is_pipe: bool = False
|
||||
self.stdout_is_pipe: bool = False
|
||||
self.stderr_is_pipe: bool = False
|
||||
self.tcflags: [any] = None
|
||||
self.cmdline: [str] = None
|
||||
self.rows: int = 0
|
||||
self.cols: int = 0
|
||||
self.hpix: int = 0
|
||||
self.vpix: int = 0
|
||||
self.stdout_buf = bytearray()
|
||||
self.stdout_eof_sent = False
|
||||
self.stderr_buf = bytearray()
|
||||
self.stderr_eof_sent = False
|
||||
self.return_code: int | None = None
|
||||
self.return_code_sent = False
|
||||
self.process: process.CallbackSubprocess | None = None
|
||||
|
||||
if self.allow_all: self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||
else: self._set_state(LSState.LSSTATE_WAIT_IDENT)
|
||||
|
||||
self.sessions.append(self)
|
||||
protocol.register_message_types(self.channel)
|
||||
self.channel.add_message_handler(self._handle_message)
|
||||
|
||||
def _terminated(self, return_code: int):
|
||||
self.return_code = return_code
|
||||
|
||||
def _set_state(self, state: LSState, timeout_factor: float = 10.0):
|
||||
timeout = max(self.outlet.rtt * timeout_factor, max(self.outlet.rtt * 2, 10)) if timeout_factor is not None else None
|
||||
RNS.log(f"Set state: {state.name}, timeout {timeout}", RNS.LOG_DEBUG)
|
||||
orig_state = self.state
|
||||
self.state = state
|
||||
if timeout_factor is not None:
|
||||
self._call(functools.partial(self._check_protocol_timeout, lambda: self.state == orig_state, state.name), timeout)
|
||||
|
||||
def _call(self, func: callable, delay: float = 0):
|
||||
def call_inner():
|
||||
if delay == 0: func()
|
||||
else: self.loop.call_later(delay, func)
|
||||
|
||||
self.loop.call_soon_threadsafe(call_inner)
|
||||
|
||||
def send(self, message: RNS.MessageBase):
|
||||
self.channel.send(message)
|
||||
|
||||
def _protocol_error(self, name: str):
|
||||
self.terminate(f"Protocol error ({name})")
|
||||
|
||||
def _protocol_timeout_error(self, name: str):
|
||||
self.terminate(f"Protocol timeout error: {name}")
|
||||
|
||||
def terminate(self, error: str = None):
|
||||
with contextlib.suppress(Exception):
|
||||
RNS.log("Terminating session" + (f": {error}" if error else ""), RNS.LOG_DEBUG)
|
||||
if error and self.state != LSState.LSSTATE_TEARDOWN:
|
||||
with contextlib.suppress(Exception):
|
||||
self.send(protocol.ErrorMessage(error, True))
|
||||
|
||||
self.state = LSState.LSSTATE_ERROR
|
||||
self._terminate_process()
|
||||
self._call(self._prune, max(self.outlet.rtt * 3, process.CallbackSubprocess.PROCESS_PIPE_TIME+5))
|
||||
|
||||
def _prune(self):
|
||||
self.state = LSState.LSSTATE_TEARDOWN
|
||||
RNS.log("Pruning session", RNS.LOG_DEBUG)
|
||||
with contextlib.suppress(ValueError):
|
||||
self.sessions.remove(self)
|
||||
with contextlib.suppress(Exception):
|
||||
self.outlet.teardown()
|
||||
|
||||
def _check_protocol_timeout(self, fail_condition: Callable[[], bool], name: str):
|
||||
timeout = True
|
||||
try: timeout = self.state != LSState.LSSTATE_TEARDOWN and fail_condition()
|
||||
except Exception as e: RNS.log(f"Error in protocol timeout: {e}", RNS.LOG_ERROR)
|
||||
if timeout: self._protocol_timeout_error(name)
|
||||
|
||||
def _link_closed(self, outlet: LSOutletBase):
|
||||
outlet.unset_link_closed_callback()
|
||||
|
||||
if outlet != self.outlet:
|
||||
RNS.log("Link closed received from incorrect outlet", RNS.LOG_DEBUG)
|
||||
return
|
||||
|
||||
RNS.log(f"link_closed {outlet}", RNS.LOG_DEBUG)
|
||||
self.terminate()
|
||||
|
||||
def _initiator_identified(self, outlet, identity):
|
||||
if outlet != self.outlet:
|
||||
RNS.log("Identity received from incorrect outlet", RNS.LOG_DEBUG)
|
||||
return
|
||||
|
||||
RNS.log(f"initiator_identified {identity} on link {outlet}", RNS.LOG_INFO)
|
||||
if self.state not in [LSState.LSSTATE_WAIT_IDENT, LSState.LSSTATE_WAIT_VERS]:
|
||||
self._protocol_error(LSState.LSSTATE_WAIT_IDENT.name)
|
||||
|
||||
if not self.allow_all and identity.hash not in self.allowed_identity_hashes and identity.hash not in self.allowed_file_identity_hashes:
|
||||
self.terminate("Identity is not allowed.")
|
||||
|
||||
self.remote_identity = identity
|
||||
self._set_state(LSState.LSSTATE_WAIT_VERS)
|
||||
|
||||
@classmethod
|
||||
async def pump_all(cls) -> True:
|
||||
processed_any = False
|
||||
for session in cls.sessions:
|
||||
processed = session.pump()
|
||||
processed_any = processed_any or processed
|
||||
await asyncio.sleep(0)
|
||||
|
||||
|
||||
@classmethod
|
||||
async def terminate_all(cls, reason: str):
|
||||
for session in cls.sessions:
|
||||
session.terminate(reason)
|
||||
await asyncio.sleep(0)
|
||||
|
||||
def pump(self) -> bool:
|
||||
def compress_adaptive(buf: bytes):
|
||||
comp_tries = RNS.RawChannelWriter.COMPRESSION_TRIES
|
||||
comp_try = 1
|
||||
comp_success = False
|
||||
|
||||
chunk_len = len(buf)
|
||||
if chunk_len > RNS.RawChannelWriter.MAX_CHUNK_LEN:
|
||||
chunk_len = RNS.RawChannelWriter.MAX_CHUNK_LEN
|
||||
chunk_segment = None
|
||||
|
||||
chunk_segment = None
|
||||
max_data_len = self.channel.mdu - protocol.StreamDataMessage.OVERHEAD
|
||||
while chunk_len > 32 and comp_try < comp_tries:
|
||||
chunk_segment_length = int(chunk_len/comp_try)
|
||||
compressed_chunk = bz2.compress(buf[:chunk_segment_length])
|
||||
compressed_length = len(compressed_chunk)
|
||||
if compressed_length < max_data_len and compressed_length < chunk_segment_length:
|
||||
comp_success = True
|
||||
break
|
||||
else:
|
||||
comp_try += 1
|
||||
|
||||
if comp_success:
|
||||
diff = max_data_len - len(compressed_chunk)
|
||||
chunk = compressed_chunk
|
||||
processed_length = chunk_segment_length
|
||||
else:
|
||||
chunk = bytes(buf[:max_data_len])
|
||||
processed_length = len(chunk)
|
||||
|
||||
return comp_success, processed_length, chunk
|
||||
|
||||
try:
|
||||
if self.state != LSState.LSSTATE_RUNNING:
|
||||
return False
|
||||
elif not self.channel.is_ready_to_send():
|
||||
return False
|
||||
elif len(self.stderr_buf) > 0:
|
||||
comp_success, processed_length, data = compress_adaptive(self.stderr_buf)
|
||||
self.stderr_buf = self.stderr_buf[processed_length:]
|
||||
send_eof = self.process.stderr_eof and len(data) == 0 and not self.stderr_eof_sent
|
||||
self.stderr_eof_sent = self.stderr_eof_sent or send_eof
|
||||
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDERR,
|
||||
data, send_eof, comp_success)
|
||||
self.send(msg)
|
||||
if send_eof:
|
||||
self.stderr_eof_sent = True
|
||||
return True
|
||||
elif len(self.stdout_buf) > 0:
|
||||
comp_success, processed_length, data = compress_adaptive(self.stdout_buf)
|
||||
self.stdout_buf = self.stdout_buf[processed_length:]
|
||||
send_eof = self.process.stdout_eof and len(data) == 0 and not self.stdout_eof_sent
|
||||
self.stdout_eof_sent = self.stdout_eof_sent or send_eof
|
||||
msg = protocol.StreamDataMessage(protocol.StreamDataMessage.STREAM_ID_STDOUT,
|
||||
data, send_eof, comp_success)
|
||||
self.send(msg)
|
||||
if send_eof:
|
||||
self.stdout_eof_sent = True
|
||||
return True
|
||||
elif self.return_code is not None and not self.return_code_sent:
|
||||
msg = protocol.CommandExitedMessage(self.return_code)
|
||||
self.send(msg)
|
||||
self.return_code_sent = True
|
||||
self._call(functools.partial(self._check_protocol_timeout,
|
||||
lambda: self.state == LSState.LSSTATE_RUNNING, "CommandExitedMessage"),
|
||||
max(self.outlet.rtt * 5, 10))
|
||||
return False
|
||||
|
||||
except Exception as e: RNS.log(f"Error during pump: {e}", RNS.LOG_ERROR)
|
||||
return False
|
||||
|
||||
def _terminate_process(self):
|
||||
with contextlib.suppress(Exception):
|
||||
if self.process and self.process.running:
|
||||
self.process.terminate()
|
||||
|
||||
def _start_cmd(self, cmdline: [str], pipe_stdin: bool, pipe_stdout: bool, pipe_stderr: bool, tcflags: [any],
|
||||
term: str | None, rows: int, cols: int, hpix: int, vpix: int):
|
||||
|
||||
self.cmdline = self.default_command
|
||||
if not self.allow_remote_command and cmdline and len(cmdline) > 0:
|
||||
self.terminate("Remote command line not allowed by listener")
|
||||
return
|
||||
|
||||
if self.remote_cmd_as_args and cmdline and len(cmdline) > 0:
|
||||
self.cmdline.extend(cmdline)
|
||||
elif cmdline and len(cmdline) > 0:
|
||||
self.cmdline = cmdline
|
||||
|
||||
|
||||
self.stdin_is_pipe = pipe_stdin
|
||||
self.stdout_is_pipe = pipe_stdout
|
||||
self.stderr_is_pipe = pipe_stderr
|
||||
self.tcflags = tcflags
|
||||
self.term = term
|
||||
|
||||
def stdout(data: bytes):
|
||||
self.stdout_buf.extend(data)
|
||||
|
||||
def stderr(data: bytes):
|
||||
self.stderr_buf.extend(data)
|
||||
|
||||
try:
|
||||
self.process = process.CallbackSubprocess(argv=self.cmdline,
|
||||
env={"TERM": self.term or os.environ.get("TERM") or "xterm",
|
||||
"RNS_REMOTE_IDENTITY": (RNS.prettyhexrep(self.remote_identity.hash)
|
||||
if self.remote_identity and self.remote_identity.hash else "")},
|
||||
loop=self.loop,
|
||||
stdout_callback=stdout,
|
||||
stderr_callback=stderr,
|
||||
terminated_callback=self._terminated,
|
||||
stdin_is_pipe=self.stdin_is_pipe,
|
||||
stdout_is_pipe=self.stdout_is_pipe,
|
||||
stderr_is_pipe=self.stderr_is_pipe)
|
||||
self.process.start()
|
||||
self._set_window_size(rows, cols, hpix, vpix)
|
||||
except Exception as e:
|
||||
RNS.log(f"Unable to start process for link {self.outlet}: {e}", RNS.LOG_ERROR)
|
||||
self.terminate("Unable to start process")
|
||||
|
||||
def _set_window_size(self, rows: int, cols: int, hpix: int, vpix: int):
|
||||
self.rows = rows
|
||||
self.cols = cols
|
||||
self.hpix = hpix
|
||||
self.vpix = vpix
|
||||
with contextlib.suppress(Exception):
|
||||
self.process.set_winsize(rows, cols, hpix, vpix)
|
||||
|
||||
def _received_stdin(self, data: bytes, eof: bool):
|
||||
if data and len(data) > 0:
|
||||
self.process.write(data)
|
||||
if eof:
|
||||
self.process.close_stdin()
|
||||
|
||||
def _handle_message(self, message: RNS.MessageBase):
|
||||
if self.state == LSState.LSSTATE_WAIT_IDENT:
|
||||
# Ignore any messages until the initiator has identified to avoid race conditions
|
||||
# between identity announcement and early protocol messages.
|
||||
RNS.log("Ignoring message while waiting for identification", RNS.LOG_DEBUG)
|
||||
return
|
||||
if self.state == LSState.LSSTATE_WAIT_VERS:
|
||||
if not isinstance(message, protocol.VersionInfoMessage):
|
||||
self._protocol_error(self.state.name)
|
||||
return
|
||||
RNS.log(f"Version {message.sw_version}, protocol {message.protocol_version} on link {self.outlet}", RNS.LOG_VERBOSE)
|
||||
if message.protocol_version != protocol.PROTOCOL_VERSION:
|
||||
self.terminate("Incompatible protocol")
|
||||
return
|
||||
self.send(protocol.VersionInfoMessage())
|
||||
self._set_state(LSState.LSSTATE_WAIT_CMD)
|
||||
return
|
||||
elif self.state == LSState.LSSTATE_WAIT_CMD:
|
||||
if not isinstance(message, protocol.ExecuteCommandMesssage):
|
||||
return self._protocol_error(self.state.name)
|
||||
RNS.log(f"Execute command message on link {self.outlet}: {message.cmdline}", RNS.LOG_VERBOSE)
|
||||
self._set_state(LSState.LSSTATE_RUNNING)
|
||||
self._start_cmd(message.cmdline, message.pipe_stdin, message.pipe_stdout, message.pipe_stderr,
|
||||
message.tcflags, message.term, message.rows, message.cols, message.hpix, message.vpix)
|
||||
return
|
||||
elif self.state == LSState.LSSTATE_RUNNING:
|
||||
if isinstance(message, protocol.WindowSizeMessage):
|
||||
self._set_window_size(message.rows, message.cols, message.hpix, message.vpix)
|
||||
elif isinstance(message, protocol.StreamDataMessage):
|
||||
if message.stream_id != protocol.StreamDataMessage.STREAM_ID_STDIN:
|
||||
RNS.log(f"Received stream data for invalid stream {message.stream_id} on link {self.outlet}", RNS.LOG_ERROR)
|
||||
return self._protocol_error(self.state.name)
|
||||
self._received_stdin(message.data, message.eof)
|
||||
return
|
||||
elif isinstance(message, protocol.NoopMessage):
|
||||
# echo noop only on listener--used for keepalive/connectivity check
|
||||
self.send(message)
|
||||
return
|
||||
elif self.state in [LSState.LSSTATE_ERROR, LSState.LSSTATE_TEARDOWN]:
|
||||
RNS.log(f"Received packet, but in state {self.state.name}", RNS.LOG_ERROR)
|
||||
return
|
||||
else:
|
||||
self._protocol_error("unexpected message")
|
||||
return
|
||||
|
||||
|
||||
class RNSOutlet(LSOutletBase):
|
||||
|
||||
def set_initiator_identified_callback(self, cb: Callable[[LSOutletBase, _TIdentity], None]):
|
||||
def inner_cb(link, identity: _TIdentity):
|
||||
cb(self, identity)
|
||||
|
||||
self.link.set_remote_identified_callback(inner_cb)
|
||||
|
||||
def set_link_closed_callback(self, cb: Callable[[LSOutletBase], None]):
|
||||
def inner_cb(link):
|
||||
cb(self)
|
||||
|
||||
self.link.set_link_closed_callback(inner_cb)
|
||||
|
||||
def unset_link_closed_callback(self):
|
||||
self.link.set_link_closed_callback(None)
|
||||
|
||||
def teardown(self):
|
||||
self.link.teardown()
|
||||
|
||||
@property
|
||||
def rtt(self) -> float:
|
||||
return self.link.rtt
|
||||
|
||||
def __str__(self):
|
||||
return f"Outlet RNS Link {self.link}"
|
||||
|
||||
def __init__(self, link: RNS.Link):
|
||||
self.link = link
|
||||
link.lsoutlet = self
|
||||
|
||||
@staticmethod
|
||||
def get_outlet(link: RNS.Link):
|
||||
if hasattr(link, "lsoutlet"):
|
||||
return link.lsoutlet
|
||||
|
||||
return RNSOutlet(link)
|
||||
Loading…
Add table
Add a link
Reference in a new issue