diff --git a/pywattbox/driver/async_driver.py b/pywattbox/driver/async_driver.py index 186d5c8..bc3a841 100644 --- a/pywattbox/driver/async_driver.py +++ b/pywattbox/driver/async_driver.py @@ -1,6 +1,8 @@ from __future__ import annotations +import asyncio import logging +import re from collections.abc import Callable from io import BytesIO from typing import Any @@ -16,9 +18,37 @@ async def on_open(driver: WattBoxAsyncDriver) -> None: - # if driver.transport_name not in ("telnet", "asynctelnet"): logger.debug("On Open") - await driver.channel._read_until_prompt() + # The 800-series presents an in-channel "Username:"/"Password:" telnet login that + # scrapli's built-in telnet auth does not satisfy (the device rejects it as + # "Invalid Login"). Over telnet, bypass scrapli auth (see __init__) and log in + # manually here. Over SSH the transport already authenticates, so keep the + # original behaviour. + if driver.transport_name in ("telnet", "asynctelnet"): + ch = driver.channel + + async def _read_until(token: bytes, timeout: float = 8.0) -> bytes: + buf = b"" + loop = asyncio.get_event_loop() + end = loop.time() + timeout + while token not in buf and loop.time() < end: + buf += await ch.read() + return buf + + await _read_until(b"Username:") + ch.write(driver.auth_username) + ch.send_return() + await _read_until(b"Password:") + ch.write(driver.auth_password) + ch.send_return() + await _read_until(b"Logged In") + # consume the trailing "!\n" so the first command read starts clean + try: + await asyncio.wait_for(ch.read(), 0.4) + except Exception: + pass + else: + await driver.channel._read_until_prompt() async def on_close(driver: WattBoxAsyncDriver) -> None: @@ -57,6 +87,11 @@ def __init__( channel_lock: bool = True, logging_uid: str = "", ) -> None: + # scrapli's telnet auth does not work against the WattBox login prompt; + # bypass it and authenticate manually in on_open. SSH keeps normal auth. + if transport in ("telnet", "asynctelnet"): + auth_bypass = True + super().__init__( host=host, port=port, @@ -89,22 +124,13 @@ async def _open(self, force: bool = False) -> None: await self.open() @timeout_modifier - async def _send_command( - self, - command: str, - ) -> Response: - """Send a command. - - Based on: - scrapli.driver.generic.async_driver.GenericDriver: send_command and _send_command - scrapli.channel.async_channel.Channel: send_input + async def _send_command(self, command: str) -> Response: + """Send a command and return its single-line response. - Args: - command: string to send to device in privilege exec mode - failed_when_contains: string or list of strings indicating failure if found in response - - Returns: - Response: Scrapli Response object + WattBox replies one line per request as ``?Key=value`` (or ``OK`` / ``#Error`` + for ``!`` control messages). The device does not reliably echo the command, and + values can contain spaces and commas (e.g. ``?OutletName``), so scrapli's prompt + matching is unreliable here -- read the matching reply line directly instead. """ await self._open() @@ -116,43 +142,26 @@ async def _send_command( logger.debug("Sending Command: %s", command) - # Normally handled in the channel `send_input`, but WattBox is special and doesn't work - # with that function. Pulled it all into the Driver for simplicity. - async with self.channel._channel_lock(): - self.channel.write(command) - self.channel.send_return() - raw_response = await self.channel._read_until_prompt() - - logger.debug("raw_response: %s", raw_response) - split_response = raw_response.strip().splitlines() - logger.debug("split_response: %s", split_response) - if ( - self.transport not in ("telnet", "asynctelnet") - and len(split_response) < 2 - ): - logger.debug("Not enough lines: %s. Getting more", len(split_response)) - raw_response += await self.channel._read_until_prompt() - logger.debug("raw_response: %s", raw_response) - split_response = raw_response.strip().splitlines() - logger.debug("split_response: %s", split_response) - - if ( - self.transport not in ("telnet", "asynctelnet") - and split_response[0] != command.encode() - ): - logger.error("Doesn't match command: %s - %s", command, split_response[0]) - + key = command.split("=", 1)[0].encode() if command.startswith("?"): - if not split_response[-1].startswith(command.encode()): - logger.error( - "Expected response to start with: %s, Got %s", - command, - split_response[-1], - ) - processed_response = split_response[1].split(b"=")[-1] + reply_pattern = re.compile(b"(?m)^" + re.escape(key) + b"=(.*)") else: - processed_response = split_response[-1] + reply_pattern = re.compile(b"(OK|#Error)") + raw_response = b"" + async with self.channel._channel_lock(): + self.channel.write(command) + self.channel.send_return() + loop = asyncio.get_event_loop() + end = loop.time() + 6.0 + while loop.time() < end and not reply_pattern.search(raw_response): + try: + raw_response += await asyncio.wait_for(self.channel.read(), 1.0) + except Exception: + break + + match = reply_pattern.search(raw_response) + processed_response = match.group(1).rstrip() if match else b"" logger.debug("processed_response: %s", processed_response) response.record_response(processed_response) response.raw_result = raw_response