Skip to content
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 60 additions & 51 deletions pywattbox/driver/async_driver.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,8 @@
from __future__ import annotations

import asyncio
import logging
import re
from collections.abc import Callable
from io import BytesIO
from typing import Any
Expand All @@ -16,9 +18,37 @@


async def on_open(driver: WattBoxAsyncDriver) -> None:
# if driver.transport_name not in ("telnet", "asynctelnet"):
logger.debug("On Open")
await driver.channel._read_until_prompt()
# The 800-series presents an in-channel "Username:"/"Password:" telnet login that
# scrapli's built-in telnet auth does not satisfy (the device rejects it as
# "Invalid Login"). Over telnet, bypass scrapli auth (see __init__) and log in
# manually here. Over SSH the transport already authenticates, so keep the
# original behaviour.
if driver.transport_name in ("telnet", "asynctelnet"):
ch = driver.channel

async def _read_until(token: bytes, timeout: float = 8.0) -> bytes:
buf = b""
loop = asyncio.get_event_loop()
end = loop.time() + timeout
while token not in buf and loop.time() < end:
buf += await ch.read()
return buf

await _read_until(b"Username:")
ch.write(driver.auth_username)
ch.send_return()
await _read_until(b"Password:")
ch.write(driver.auth_password)
ch.send_return()
await _read_until(b"Logged In")
# consume the trailing "!\n" so the first command read starts clean
try:
await asyncio.wait_for(ch.read(), 0.4)
except Exception:
pass
else:
await driver.channel._read_until_prompt()


async def on_close(driver: WattBoxAsyncDriver) -> None:
Expand Down Expand Up @@ -57,6 +87,11 @@ def __init__(
channel_lock: bool = True,
logging_uid: str = "",
) -> None:
# scrapli's telnet auth does not work against the WattBox login prompt;
# bypass it and authenticate manually in on_open. SSH keeps normal auth.
if transport in ("telnet", "asynctelnet"):
auth_bypass = True

super().__init__(
host=host,
port=port,
Expand Down Expand Up @@ -89,22 +124,13 @@ async def _open(self, force: bool = False) -> None:
await self.open()

@timeout_modifier
async def _send_command(
self,
command: str,
) -> Response:
"""Send a command.

Based on:
scrapli.driver.generic.async_driver.GenericDriver: send_command and _send_command
scrapli.channel.async_channel.Channel: send_input
async def _send_command(self, command: str) -> Response:
"""Send a command and return its single-line response.

Args:
command: string to send to device in privilege exec mode
failed_when_contains: string or list of strings indicating failure if found in response

Returns:
Response: Scrapli Response object
WattBox replies one line per request as ``?Key=value`` (or ``OK`` / ``#Error``
for ``!`` control messages). The device does not reliably echo the command, and
values can contain spaces and commas (e.g. ``?OutletName``), so scrapli's prompt
matching is unreliable here -- read the matching reply line directly instead.
"""
await self._open()

Expand All @@ -116,43 +142,26 @@ async def _send_command(

logger.debug("Sending Command: %s", command)

# Normally handled in the channel `send_input`, but WattBox is special and doesn't work
# with that function. Pulled it all into the Driver for simplicity.
async with self.channel._channel_lock():
self.channel.write(command)
self.channel.send_return()
raw_response = await self.channel._read_until_prompt()

logger.debug("raw_response: %s", raw_response)
split_response = raw_response.strip().splitlines()
logger.debug("split_response: %s", split_response)
if (
self.transport not in ("telnet", "asynctelnet")
and len(split_response) < 2
):
logger.debug("Not enough lines: %s. Getting more", len(split_response))
raw_response += await self.channel._read_until_prompt()
logger.debug("raw_response: %s", raw_response)
split_response = raw_response.strip().splitlines()
logger.debug("split_response: %s", split_response)

if (
self.transport not in ("telnet", "asynctelnet")
and split_response[0] != command.encode()
):
logger.error("Doesn't match command: %s - %s", command, split_response[0])

key = command.split("=", 1)[0].encode()
if command.startswith("?"):
if not split_response[-1].startswith(command.encode()):
logger.error(
"Expected response to start with: %s, Got %s",
command,
split_response[-1],
)
processed_response = split_response[1].split(b"=")[-1]
reply_pattern = re.compile(b"(?m)^" + re.escape(key) + b"=(.*)")
else:
processed_response = split_response[-1]
reply_pattern = re.compile(b"(OK|#Error)")

raw_response = b""
async with self.channel._channel_lock():
self.channel.write(command)
self.channel.send_return()
loop = asyncio.get_event_loop()
end = loop.time() + 6.0
while loop.time() < end and not reply_pattern.search(raw_response):
try:
raw_response += await asyncio.wait_for(self.channel.read(), 1.0)
except Exception:
break

match = reply_pattern.search(raw_response)
processed_response = match.group(1).rstrip() if match else b""
logger.debug("processed_response: %s", processed_response)
response.record_response(processed_response)
response.raw_result = raw_response
Expand Down