Skip to content
This repository has been archived by the owner on Nov 10, 2023. It is now read-only.

Code dup #33

Open
wants to merge 13 commits into
base: master
Choose a base branch
from
5 changes: 5 additions & 0 deletions netdev/connections/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
"""
Connections Module, classes that handle the protocols connection like ssh,telnet and serial.
"""
from .ssh import SSHConnection
from .telnet import TelnetConnection
100 changes: 100 additions & 0 deletions netdev/connections/base.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,100 @@
"""
Base Connection Module
"""
import re
import asyncio
from netdev.logger import logger
from .interface import IConnection


class BaseConnection(IConnection):

def __init__(self, *args, **kwargs):
self._host = None
self._timeout = None
self._conn = None
self._base_prompt = self._base_pattern = ""
self._MAX_BUFFER = 65535

async def __aenter__(self):
"""Async Context Manager"""
await self.connect()
selfuryon marked this conversation as resolved.
Show resolved Hide resolved
return self

async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async Context Manager"""
await self.disconnect()

@property
def _logger(self):
return logger

def set_base_prompt(self, prompt):
""" base prompt setter """
self._base_prompt = prompt

def set_base_pattern(self, pattern):
""" base patter setter """
self._base_pattern = pattern

async def disconnect(self):
""" Close Connection """
raise NotImplementedError("Connection must implement disconnect method")

async def connect(self):
""" Establish Connection """
raise NotImplementedError("Connection must implement connect method")

def send(self, cmd):
""" send data """
raise NotImplementedError("Connection must implement send method")

async def read(self):
""" read from buffer """
raise NotImplementedError("Connection must implement read method ")

async def read_until_pattern(self, pattern, re_flags=0):
"""Read channel until pattern detected. Return ALL data available"""

if pattern is None:
raise ValueError("pattern cannot be None")

if isinstance(pattern, str):
pattern = [pattern]
output = ""
logger.info("Host {}: Reading until pattern".format(self._host))

logger.debug("Host {}: Reading pattern: {}".format(self._host, pattern))
while True:

fut = self.read()
try:
output += await asyncio.wait_for(fut, self._timeout)
except asyncio.TimeoutError:
raise TimeoutError(self._host)

for exp in pattern:
if re.search(exp, output, flags=re_flags):
logger.debug(
"Host {}: Reading pattern '{}' was found: {}".format(
self._host, pattern, repr(output)
)
)
return output

async def read_until_prompt(self):
""" read util prompt """
return await self.read_until_pattern(self._base_pattern)

async def read_until_prompt_or_pattern(self, pattern, re_flags=0):
""" read util prompt or pattern """

logger.info("Host {}: Reading until prompt or pattern".format(self._host))

if isinstance(pattern, str):
pattern = [self._base_prompt, pattern]
elif isinstance(pattern, list):
pattern = [self._base_prompt] + pattern
else:
raise ValueError("pattern must be string or list of strings")
return await self.read_until_pattern(pattern=pattern, re_flags=re_flags)
52 changes: 52 additions & 0 deletions netdev/connections/interface.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
"""
Connection Interface
"""
import abc


class IConnection(abc.ABC):

@abc.abstractmethod
async def __aenter__(self):
"""Async Context Manager"""
pass

@abc.abstractmethod
async def __aexit__(self, exc_type, exc_val, exc_tb):
"""Async Context Manager"""
pass

@abc.abstractmethod
async def disconnect(self):
""" Close Connection """
pass

@abc.abstractmethod
async def connect(self):
""" Establish Connection """
pass

@abc.abstractmethod
async def send(self, cmd):
""" send Command """
pass

@abc.abstractmethod
async def read(self):
""" send Command """
pass

@abc.abstractmethod
async def read_until_pattern(self, pattern, re_flags=0):
""" read util pattern """
pass

@abc.abstractmethod
async def read_until_prompt(self):
""" read util pattern """
pass

@abc.abstractmethod
async def read_until_prompt_or_pattern(self, attern, re_flags=0):
""" read util pattern """
pass
Empty file added netdev/connections/serial.py
Empty file.
123 changes: 123 additions & 0 deletions netdev/connections/ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""
SSH Connection Module
"""
import asyncio
import asyncssh
from netdev.constants import TERM_LEN, TERM_WID, TERM_TYPE
from netdev.exceptions import DisconnectError
from .base import BaseConnection


class SSHConnection(BaseConnection):
def __init__(self,
host=u"",
username=u"",
password=u"",
port=22,
timeout=15,
loop=None,
known_hosts=None,
local_addr=None,
client_keys=None,
passphrase=None,
tunnel=None,
pattern=None,
agent_forwarding=False,
agent_path=(),
client_version=u"netdev-{}",
family=0,
kex_algs=(),
encryption_algs=(),
mac_algs=(),
compression_algs=(),
signature_algs=()):
super().__init__()
if host:
self._host = host
else:
raise ValueError("Host must be set")
self._port = int(port)
self._timeout = timeout
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop

connect_params_dict = {
"host": self._host,
"port": self._port,
"username": username,
"password": password,
"known_hosts": known_hosts,
"local_addr": local_addr,
"client_keys": client_keys,
"passphrase": passphrase,
"tunnel": tunnel,
"agent_forwarding": agent_forwarding,
"loop": loop,
"family": family,
"agent_path": agent_path,
"client_version": client_version,
"kex_algs": kex_algs,
"encryption_algs": encryption_algs,
"mac_algs": mac_algs,
"compression_algs": compression_algs,
"signature_algs": signature_algs
}

if pattern is not None:
self._pattern = pattern

self._conn_dict = connect_params_dict
self._timeout = timeout

async def connect(self):
""" Etablish SSH connection """
self._logger.info("Host {}: SSH: Establishing SSH connection on port {}".format(self._host, self._port))

fut = asyncssh.connect(**self._conn_dict)
try:
self._conn = await asyncio.wait_for(fut, self._timeout)
except asyncssh.DisconnectError as e:
raise DisconnectError(self._host, e.code, e.reason)
except asyncio.TimeoutError:
raise TimeoutError(self._host)

await self._start_session()

async def disconnect(self):
""" Gracefully close the SSH connection """
self._logger.info("Host {}: SSH: Disconnecting".format(self._host))
self._logger.info("Host {}: SSH: Disconnecting".format(self._host))
await self._cleanup()
self._conn.close()
await self._conn.wait_closed()

def send(self, cmd):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We have some differences here.
Code here is def send(self, cmd), but in parent class is async def send(self, cmd). We need to do the same for removing the ambiguous.
I think that we can leave it a regular function like asyncssh's authors and asyncio's authors made it

self._stdin.write(cmd)

async def read(self):
return await self._stdout.read(self._MAX_BUFFER)

def __check_session(self):
""" check session was opened """
if not self._stdin:
raise RuntimeError("SSH session not started")

async def _start_session(self):
""" start interactive-session (shell) """
self._logger.info(
"Host {}: SSH: Starting Interacive session term_type={}, term_width={}, term_length={}".format(
self._host, TERM_TYPE, TERM_WID, TERM_LEN))
self._stdin, self._stdout, self._stderr = await self._conn.open_session(
term_type=TERM_TYPE, term_size=(TERM_WID, TERM_LEN)
)

async def _cleanup(self):
pass

async def close(self):
""" Close Connection """
await self._cleanup()
self._conn.close()
await self._conn.wait_closed()
82 changes: 82 additions & 0 deletions netdev/connections/telnet.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,82 @@
"""
Telnet Connection Module
"""
import asyncio
from netdev.exceptions import DisconnectError, TimeoutError
from .base import BaseConnection


class TelnetConnection(BaseConnection):
def __init__(self,
host=u"",
username=u"",
password=u"",
port=23,
timeout=15,
loop=None,
pattern=None, ):
super().__init__()
if host:
self._host = host
else:
raise ValueError("Host must be set")
self._port = int(port)
self._timeout = timeout
self._username = username
self._password = password
if loop is None:
self._loop = asyncio.get_event_loop()
else:
self._loop = loop

if pattern is not None:
self._pattern = pattern

self._timeout = timeout

async def _start_session(self):
""" start Telnet Session by login to device """
self._logger.info("Host {}: telnet: trying to login to device".format(self._host))
output = await self.read_until_pattern(['username', 'Username'])
self.send(self._username + '\n')
output += await self.read_until_pattern(['password', 'Password'])
self.send(self._password + '\n')
output += await self.read_until_prompt()
self.send('\n')
if 'Login invalid' in output:
raise DisconnectError(self._host, None, "authentication failed")

def __check_session(self):
if not self._stdin:
raise RuntimeError("telnet session not started")

@asyncio.coroutine
def connect(self):
""" Establish Telnet Connection """
self._logger.info("Host {}: telnet: Establishing Telnet Connection on port {}".format(self._host, self._port))
fut = asyncio.open_connection(self._host, self._port, family=0, flags=0)
try:
self._stdout, self._stdin = yield from asyncio.wait_for(fut, self._timeout)
except asyncio.TimeoutError:
raise TimeoutError(self._host)
except Exception as e:
raise DisconnectError(self._host, None, str(e))

yield from self._start_session()

async def disconnect(self):
""" Gracefully close the Telnet connection """
self._logger.info("Host {}: telnet: Disconnecting".format(self._host))
self._logger.info("Host {}: telnet: Disconnecting".format(self._host))
self._conn.close()
await self._conn.wait_closed()

def send(self, cmd):
Copy link
Owner

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The same thing like in here.
I'd prefer to leave it as a regular function, not a coroutine.

self._stdin.write(cmd.encode())

async def read(self):
output = await self._stdout.read(self._MAX_BUFFER)
return output.decode(errors='ignore')

async def close(self):
pass
Loading