Skip to content

Commit 22a481b

Browse files
committed
Merge branch 'pydanticv2-upgrade' PR# 8
Pydantic v2 upgrade Update to a newer generation of pydantic. Initial changes followed the guide here: https://docs.pydantic.dev/latest/migration/ The major challenge here was pyamplipi's use of pydantic;'s ModelMetaclass and ModelField which were made private by the pydantic devs (for good reasons). This code has been tested via: - upcoming automated tests in https://github.com/micro-nova/pyamplipi/tree/add-tests - via manual cli testing with an inhouse amplipi running 0.4.5
2 parents 422eaec + 5724718 commit 22a481b

File tree

4 files changed

+165
-157
lines changed

4 files changed

+165
-157
lines changed

pyamplipi/__main__.py

Lines changed: 46 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -5,13 +5,12 @@
55
import os
66
import datetime
77
from argparse import ArgumentParser, ArgumentDefaultsHelpFormatter, Namespace, Action, ArgumentError
8+
import typing
89
from typing import Optional, List, Callable, Sequence, Dict, Any, Union, Type
910
from textwrap import indent
10-
import json
1111
import yaml
1212
from pydantic import BaseModel
13-
from pydantic.fields import ModelField
14-
from pydantic.main import ModelMetaclass
13+
from pydantic.fields import FieldInfo
1514
from dotenv import load_dotenv
1615
from aiohttp import TCPConnector, ClientSession
1716
from aiohttp.client_exceptions import ServerDisconnectedError
@@ -26,9 +25,9 @@
2625
# pylint: disable=logging-fstring-interpolation
2726

2827
# constants
29-
log = logging.getLogger(__name__) # central logging channel
30-
json_ser_kwargs: Dict[str, Any] = dict(
31-
exclude_unset=True, indent=2) # arguments to serialise the json
28+
log = logging.getLogger(__name__) # central logging channel
29+
# arguments to serialise the json
30+
json_ser_kwargs: Dict[str, Any] = {'exclude_unset': True, 'indent': 2}
3231

3332

3433
# text formatters
@@ -44,7 +43,7 @@ def table(d, h) -> str:
4443

4544
# simple list json for List[BaseModel] constructs
4645
def model_list_to_json(it: Sequence[BaseModel]) -> str:
47-
return f"[{','.join([i.json(**json_ser_kwargs) for i in it])}]"
46+
return f"[{','.join([i.model_dump_json(**json_ser_kwargs) for i in it])}]"
4847

4948

5049
# list methods dumping comprehensive output to stdout
@@ -166,7 +165,7 @@ async def do_status_get(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs
166165
"""
167166
log.debug("status.get()")
168167
status: Status = await amplipi.get_status()
169-
write_out(status.json(**json_ser_kwargs), args.outfile)
168+
write_out(status.model_dump_json(**json_ser_kwargs), args.outfile)
170169

171170

172171
async def do_config_load(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
@@ -227,7 +226,7 @@ async def do_info_get(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
227226
"""
228227
log.debug("status.info()")
229228
info: Info = await amplipi.get_info()
230-
write_out(info.json(**json_ser_kwargs), args.outfile)
229+
write_out(info.model_dump_json(**json_ser_kwargs), args.outfile)
231230

232231

233232
# -- source section
@@ -245,7 +244,7 @@ async def do_source_get(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs
245244
log.debug(f"source.get({args.sourceid})")
246245
assert 0 <= args.sourceid <= 3, "source id must be in range 0..3"
247246
source: Source = await amplipi.get_source(args.sourceid)
248-
write_out(source.json(**json_ser_kwargs), args.outfile)
247+
write_out(source.model_dump_json(**json_ser_kwargs), args.outfile)
249248

250249

251250
async def do_source_getall(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
@@ -297,7 +296,7 @@ async def do_zone_get(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
297296
log.debug(f"zone.get({args.zoneid})")
298297
assert 0 <= args.zoneid <= 35, "zone id must be in range 0..35"
299298
zone: Zone = await amplipi.get_zone(args.zoneid)
300-
write_out(zone.json(**json_ser_kwargs), args.outfile)
299+
write_out(zone.model_dump_json(**json_ser_kwargs), args.outfile)
301300

302301

303302
async def do_zone_getall(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
@@ -348,7 +347,7 @@ async def do_group_get(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs)
348347
log.debug(f"group.get({args.groupid})")
349348
assert 0 <= args.groupid, "group id must be > 0"
350349
group: Group = await amplipi.get_group(args.groupid)
351-
write_out(group.json(**json_ser_kwargs), args.outfile)
350+
write_out(group.model_dump_json(**json_ser_kwargs), args.outfile)
352351

353352

354353
async def do_group_getall(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
@@ -414,7 +413,7 @@ async def do_stream_get(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs
414413
log.debug(f"stream.get({args.streamid})")
415414
assert 0 <= args.streamid, "stream id must be > 0"
416415
stream: Stream = await amplipi.get_stream(args.streamid)
417-
write_out(stream.json(**json_ser_kwargs), args.outfile)
416+
write_out(stream.model_dump_json(**json_ser_kwargs), args.outfile)
418417

419418

420419
async def do_stream_getall(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
@@ -529,7 +528,7 @@ async def do_preset_get(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs
529528
log.debug(f"preset.get({args.presetid})")
530529
assert 0 <= args.presetid, "preset id must be > 0"
531530
preset: Preset = await amplipi.get_preset(args.presetid)
532-
write_out(preset.json(**json_ser_kwargs), args.outfile)
531+
write_out(preset.model_dump_json(**json_ser_kwargs), args.outfile)
533532

534533

535534
async def do_preset_getall(args: Namespace, amplipi: AmpliPi, shell: bool, **kwargs):
@@ -670,7 +669,7 @@ async def shell_cmd_exec(cmdline: str, amplipi: AmpliPi, argsparser: ArgumentPar
670669
print(e)
671670

672671

673-
def instantiate_model(model_cls: ModelMetaclass, infile: str, _input: Optional[Dict[str, Any]] = None,
672+
def instantiate_model(model_cls, infile: str, _input: Optional[Dict[str, Any]] = None,
674673
validate: Optional[Callable] = None):
675674
""" Instatiates the passed BaseModel based on:
676675
(1) either the passed input dict (if not None) merged with env var defaults
@@ -688,10 +687,10 @@ def instantiate_model(model_cls: ModelMetaclass, infile: str, _input: Optional[D
688687
validate(_input)
689688
return model_cls(**_input)
690689
# else read the object from stdin (json)
691-
return model_cls.parse_obj(json.loads(read_in(infile))) # type: ignore
690+
return model_cls.model_validate_json(read_in(infile)) # type: ignore
692691

693692

694-
def merge_model_kwargs(model_cls: ModelMetaclass, input: dict) -> Dict[str, Any]:
693+
def merge_model_kwargs(model_cls, _input: dict) -> Dict[str, Any]:
695694
""" Builds the kwargs needed to construct the passed BaseModel by merging the passed input dict
696695
with possible available environment variables with key following this pattern:
697696
"AMPLIPI_" + «name of BaseModel» + "_" + «name of field in BaseModel» (in all caps)
@@ -700,32 +699,40 @@ def envvar(name):
700699
envkey = f"AMPLIPI_{model_cls.__name__}_{name}".upper()
701700
return os.getenv(envkey)
702701
kwargs = dict()
703-
for name, modelfield in model_cls.__fields__.items(): # type: ignore
704-
value_str: str = input.get(name, envvar(name))
702+
for name, modelfield in model_cls.model_fields.items(): # type: ignore
703+
value_str: str = _input.get(name, envvar(name))
705704
if value_str is not None and isinstance(value_str, str) and len(value_str) > 0:
706705
value = parse_valuestr(value_str, modelfield)
707706
log.debug(
708-
f"converted {value_str} to {value} for {modelfield.type_}")
707+
f"converted {value_str} to {value} for {modelfield.annotation}")
709708
kwargs[name] = value
710709
return kwargs
711710

712711

713712
# helper functions for the arguments parsing
714-
def parse_valuestr(val_str: str, modelfield: ModelField):
715-
""" Uses the pydantic defined Modelfield to correctly parse CLI passed string-values to typed values
716-
Supports simple types and lists of them
717-
"""
718-
convertor = modelfield.type_
719-
if convertor == bool:
720-
def boolconvertor(s):
713+
def parse_valuestr(val_str: str, modelfield: FieldInfo):
714+
""" Uses the pydantic defined FieldInfo to correctly parse CLI passed string-values to typed values
715+
Supports simple types and lists of them which can be wrapped in Optional.
716+
TODO: This is fairly fragile. We should find a more robust solution.
717+
"""
718+
converter: Union[Type, None, Callable] = modelfield.annotation
719+
720+
if getattr(converter, '_name', None) == "Optional":
721+
# Optional needs to be manually unwrapped to the inner type
722+
converter = typing.get_args(converter)[0] # unwrap Optional
723+
if converter is bool:
724+
def boolconverter(s):
721725
return len(s) > 0 and s.lower() in ('y', 'yes', '1', 'true', 'on')
722-
convertor = boolconvertor
723-
if modelfield.outer_type_.__name__ == 'List':
726+
converter = boolconverter
727+
if converter is list:
724728
assert val_str[0] == '[' and val_str[-1] == ']', "expected array-value needs to be surrounded with []"
725729
val_str = val_str[1:-1]
726-
return [convertor(v.strip()) for v in val_str.split(',')]
727-
# else
728-
return convertor(val_str)
730+
return [converter(v.strip()) for v in val_str.split(',')]
731+
if converter is None:
732+
log.warning(
733+
f"no converter for {modelfield.title} not converting {val_str}")
734+
return val_str
735+
return converter(val_str)
729736

730737

731738
class ParseDict(Action):
@@ -756,7 +763,7 @@ def add_force_argument(ap: ArgumentParser):
756763
help="force the command to be executed without interaction.")
757764

758765

759-
def add_id_argument(ap: ArgumentParser, model_cls: ModelMetaclass):
766+
def add_id_argument(ap: ArgumentParser, model_cls):
760767
""" Adds the --input argument in a consistent way
761768
"""
762769
name = model_cls.__name__.lower()
@@ -766,7 +773,7 @@ def add_id_argument(ap: ArgumentParser, model_cls: ModelMetaclass):
766773
help="identifier of the {name} (integer)")
767774

768775

769-
def add_input_arguments(ap: ArgumentParser, model_cls: ModelMetaclass, too_complex_for_cli_keyvals: bool = False):
776+
def add_input_arguments(ap: ArgumentParser, model_cls, too_complex_for_cli_keyvals: bool = False):
770777
""" Adds the --input -i and --infile -I argument in a consistent way
771778
The -i argument takes key-value pairs to construct models rather then provide those in json via stdin (for simple models only)
772779
The -I argument specifies an input file to use in stead of stdin
@@ -780,7 +787,7 @@ def add_input_arguments(ap: ArgumentParser, model_cls: ModelMetaclass, too_compl
780787
if too_complex_for_cli_keyvals:
781788
return
782789
# else allow key-val --input
783-
fields = model_cls.__fields__.keys() # type: ignore
790+
fields = model_cls.model_fields.keys() # type: ignore
784791
ap.add_argument(
785792
'--input', '-i',
786793
action=ParseDict,
@@ -1172,14 +1179,14 @@ def enable_logging(logconf=None):
11721179

11731180

11741181
# helper function to instantiate the client
1175-
def make_amplipi(args: Namespace) -> AmpliPi:
1182+
def make_amplipi(args: Namespace, loop) -> AmpliPi:
11761183
""" Constructs the amplipi client
11771184
"""
11781185
endpoint: str = args.amplipi
11791186
timeout: int = args.timeout
11801187
# in shell modus we got frequent server-disconnected-errors - injecting this custom session avoids that
1181-
connector: TCPConnector = TCPConnector(force_close=True)
1182-
http_session: ClientSession = ClientSession(connector=connector)
1188+
connector: TCPConnector = TCPConnector(force_close=True, loop=loop)
1189+
http_session: ClientSession = ClientSession(connector=connector, loop=loop)
11831190
return AmpliPi(endpoint, timeout=timeout, http_session=http_session)
11841191

11851192

@@ -1198,10 +1205,10 @@ def main():
11981205
sys.exit(1)
11991206

12001207
enable_logging(logconf=args.logconf)
1201-
amplipi = make_amplipi(args)
12021208

12031209
# setup async wait construct for main routines
12041210
loop = asyncio.get_event_loop_policy().get_event_loop()
1211+
amplipi = make_amplipi(args, loop)
12051212
try:
12061213
# trigger the actual called action-function (async) and wait for it
12071214
loop.run_until_complete(

0 commit comments

Comments
 (0)