Skip to content

Commit

Permalink
Clean up typing (#64)
Browse files Browse the repository at this point in the history
* clean up typing

* fix tests

* lint
blink1073 authored Jan 12, 2023
1 parent 88cc054 commit ac65980
Showing 10 changed files with 60 additions and 44 deletions.
2 changes: 1 addition & 1 deletion jupyter_events/cli.py
Original file line number Diff line number Diff line change
@@ -52,7 +52,7 @@ def main():
@click.command()
@click.argument("schema")
@click.pass_context
def validate(ctx: click.Context, schema: str):
def validate(ctx: click.Context, schema: str) -> int:
"""Validate a SCHEMA against Jupyter Event's meta schema.
SCHEMA can be a JSON/YAML string or filepath to a schema.
30 changes: 16 additions & 14 deletions jupyter_events/logger.py
Original file line number Diff line number Diff line change
@@ -8,14 +8,14 @@
import logging
import warnings
from datetime import datetime
from pathlib import PurePath
from typing import Callable, Optional, Union
from typing import Any, Callable, Coroutine, Optional, Union

from jsonschema import ValidationError
from pythonjsonlogger import jsonlogger # type:ignore
from traitlets import Dict, Instance, Set, default
from traitlets.config import Config, LoggingConfigurable

from .schema import SchemaType
from .schema_registry import SchemaRegistry
from .traits import Handlers
from .validators import JUPYTER_EVENTS_CORE_VALIDATOR
@@ -131,7 +131,7 @@ def get_handlers():
eventlogger_cfg = Config({"EventLogger": my_cfg})
super()._load_config(eventlogger_cfg, section_names=None, traits=None)

def register_event_schema(self, schema: Union[dict, str, PurePath]):
def register_event_schema(self, schema: SchemaType) -> None:
"""Register this schema with the schema registry.
Get this registered schema using the EventLogger.schema.get() method.
@@ -143,7 +143,7 @@ def register_event_schema(self, schema: Union[dict, str, PurePath]):
self._modified_listeners[key] = set()
self._unmodified_listeners[key] = set()

def register_handler(self, handler: logging.Handler):
def register_handler(self, handler: logging.Handler) -> None:
"""Register a new logging handler to the Event Logger.
All outgoing messages will be formatted as a JSON string.
@@ -164,7 +164,7 @@ def _skip_message(record, **kwargs):
if handler not in self.handlers:
self.handlers.append(handler)

def remove_handler(self, handler: logging.Handler):
def remove_handler(self, handler: logging.Handler) -> None:
"""Remove a logging handler from the logger and list of handlers."""
self._logger.removeHandler(handler)
if handler in self.handlers:
@@ -175,7 +175,7 @@ def add_modifier(
*,
schema_id: Union[str, None] = None,
modifier: Callable[[str, dict], dict],
):
) -> None:
"""Add a modifier (callable) to a registered event.
Parameters
@@ -249,8 +249,8 @@ def add_listener(
*,
modified: bool = True,
schema_id: Union[str, None] = None,
listener: Callable[["EventLogger", str, dict], None],
):
listener: Callable[["EventLogger", str, dict], Coroutine[Any, Any, None]],
) -> None:
"""Add a listener (callable) to a registered event.
Parameters
@@ -304,7 +304,7 @@ def remove_listener(
self,
*,
schema_id: Optional[str] = None,
listener: Callable[["EventLogger", str, dict], None],
listener: Callable[["EventLogger", str, dict], Coroutine[Any, Any, None]],
) -> None:
"""Remove a listener from an event or all events.
@@ -327,7 +327,9 @@ def remove_listener(
self._modified_listeners[schema_id].discard(listener)
self._unmodified_listeners[schema_id].discard(listener)

def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
def emit(
self, *, schema_id: str, data: dict, timestamp_override: Optional[datetime] = None
) -> Optional[dict]:
"""
Record given event with schema has occurred.
@@ -351,7 +353,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
and not self._modified_listeners[schema_id]
and not self._unmodified_listeners[schema_id]
):
return
return None

# If the schema hasn't been registered, raise a warning to make sure
# this was intended.
@@ -362,7 +364,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):
"`register_event_schema` method.",
SchemaNotRegistered,
)
return
return None

schema = self.schemas.get(schema_id)

@@ -400,7 +402,7 @@ def emit(self, *, schema_id: str, data: dict, timestamp_override=None):

# callback for removing from finished listeners
# from active listeners set.
def _listener_task_done(task: asyncio.Task):
def _listener_task_done(task: asyncio.Task) -> None:
# If an exception happens, log it to the main
# applications logger
err = task.exception()
@@ -429,7 +431,7 @@ def _listener_task_done(task: asyncio.Task):
self._active_listeners.add(task)

# Remove task from active listeners once its finished.
def _listener_task_done(task: asyncio.Task):
def _listener_task_done(task: asyncio.Task) -> None:
# If an exception happens, log it to the main
# applications logger
err = task.exception()
21 changes: 12 additions & 9 deletions jupyter_events/schema.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
"""Event schema objects."""
import json
from pathlib import Path, PurePath
from typing import Type, Union
from typing import Optional, Type, Union

from jsonschema import FormatChecker, validators
from jsonschema import FormatChecker, RefResolver, validators

try:
from jsonschema.protocols import Validator
@@ -34,6 +34,9 @@ class EventSchemaFileAbsent(Exception): # noqa
pass


SchemaType = Union[dict, str, PurePath]


class EventSchema:
"""A validated schema that can be used.
@@ -58,10 +61,10 @@ class EventSchema:

def __init__(
self,
schema: Union[dict, str, PurePath],
validator_class: Type[Validator] = validators.Draft7Validator, # type:ignore
schema: SchemaType,
validator_class: Type[Validator] = validators.Draft7Validator, # type:ignore[assignment]
format_checker: FormatChecker = draft7_format_checker,
resolver=None,
resolver: Optional[RefResolver] = None,
):
"""Initialize an event schema."""
_schema = self._load_schema(schema)
@@ -76,29 +79,29 @@ def __repr__(self):
return json.dumps(self._schema, indent=2)

@staticmethod
def _ensure_yaml_loaded(schema, was_str=False) -> None:
def _ensure_yaml_loaded(schema: SchemaType, was_str: bool = False) -> None:
"""Ensures schema was correctly loaded into a dictionary. Raises
EventSchemaLoadingError otherwise."""
if isinstance(schema, dict):
return

error_msg = "Could not deserialize schema into a dictionary."

def intended_as_path(schema):
def intended_as_path(schema: str) -> bool:
path = Path(schema)
return path.match("*.yml") or path.match("*.yaml") or path.match("*.json")

# detect whether the user specified a string but intended a PurePath to
# generate a more helpful error message
if was_str and intended_as_path(schema):
if was_str and intended_as_path(schema): # type:ignore[arg-type]
error_msg += " Paths to schema files must be explicitly wrapped in a Pathlib object."
else:
error_msg += " Double check the schema and ensure it is in the proper form."

raise EventSchemaLoadingError(error_msg)

@staticmethod
def _load_schema(schema: Union[dict, str, PurePath]) -> dict:
def _load_schema(schema: SchemaType) -> dict:
"""Load a JSON schema from different sources/data types.
`schema` could be a dictionary or serialized string representing the
4 changes: 2 additions & 2 deletions jupyter_events/schema_registry.py
Original file line number Diff line number Diff line change
@@ -15,15 +15,15 @@ def __init__(self, schemas: Optional[dict] = None):
"""Initialize the registry."""
self._schemas = schemas or {}

def __contains__(self, key: str):
def __contains__(self, key: str) -> bool:
"""Syntax sugar to check if a schema is found in the registry"""
return key in self._schemas

def __repr__(self) -> str:
"""The str repr of the registry."""
return ",\n".join([str(s) for s in self._schemas.values()])

def _add(self, schema_obj: EventSchema):
def _add(self, schema_obj: EventSchema) -> None:
if schema_obj.id in self._schemas:
msg = (
f"The schema, {schema_obj.id}, is already "
2 changes: 1 addition & 1 deletion jupyter_events/validators.py
Original file line number Diff line number Diff line change
@@ -46,7 +46,7 @@
)


def validate_schema(schema: dict):
def validate_schema(schema: dict) -> None:
"""Validate a schema dict."""
try:
# Validate the schema against Jupyter Events metaschema.
14 changes: 14 additions & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
@@ -133,6 +133,20 @@ exclude_lines = [
"@(abc\\.)?abstractmethod",
]

[tool.mypy]
check_untyped_defs = true
disallow_incomplete_defs = true
no_implicit_optional = true
pretty = true
show_error_context = true
show_error_codes = true
strict_equality = true
warn_unused_configs = true
warn_unused_ignores = true
warn_redundant_casts = true
explicit_package_bases = true
namespace_packages = true

[tool.black]
line-length = 100
skip-string-normalization = true
15 changes: 6 additions & 9 deletions tests/test_listeners.py
Original file line number Diff line number Diff line change
@@ -23,12 +23,11 @@ def jp_event_schemas(schema):

async def test_listener_function(jp_event_logger, schema):
event_logger = jp_event_logger
global listener_was_called
listener_was_called = False

async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
global listener_was_called
listener_was_called = True # type: ignore
nonlocal listener_was_called
listener_was_called = True

# Add the modifier
event_logger.add_listener(schema_id=schema.id, listener=my_listener)
@@ -41,12 +40,11 @@ async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:

async def test_remove_listener_function(jp_event_logger, schema):
event_logger = jp_event_logger
global listener_was_called
listener_was_called = False

async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
global listener_was_called
listener_was_called = True # type: ignore
nonlocal listener_was_called
listener_was_called = True

# Add the modifier
event_logger.add_listener(schema_id=schema.id, listener=my_listener)
@@ -114,15 +112,14 @@ async def test_bad_listener_does_not_break_good_listener(jp_event_logger, schema
h = logging.StreamHandler(log_stream)
app_log.addHandler(h)

global listener_was_called
listener_was_called = False

async def listener_raise_exception(logger: EventLogger, schema_id: str, data: dict) -> None:
raise Exception("This failed") # noqa

async def my_listener(logger: EventLogger, schema_id: str, data: dict) -> None:
global listener_was_called
listener_was_called = True # type: ignore
nonlocal listener_was_called
listener_was_called = True

# Add a bad listener and a good listener and ensure that
# emitting still works and the bad listener's exception is is logged.
10 changes: 5 additions & 5 deletions tests/test_modifiers.py
Original file line number Diff line number Diff line change
@@ -54,20 +54,20 @@ def redact(self, schema_id: str, data: dict) -> dict:
assert output["username"] == "<masked>"


def test_bad_modifier_functions(jp_event_logger, schema: EventSchema):
def test_bad_modifier_functions(jp_event_logger: EventLogger, schema: EventSchema) -> None:
event_logger = jp_event_logger

def modifier_with_extra_args(schema_id: str, data: dict, unknown_arg: dict) -> dict:
return data

with pytest.raises(ModifierError):
event_logger.add_modifier(modifier=modifier_with_extra_args)
event_logger.add_modifier(modifier=modifier_with_extra_args) # type:ignore[arg-type]

# Ensure no modifier was added.
assert len(event_logger._modifiers[schema.id]) == 0


def test_bad_modifier_method(jp_event_logger, schema: EventSchema):
def test_bad_modifier_method(jp_event_logger: EventLogger, schema: EventSchema) -> None:
event_logger = jp_event_logger

class Redactor:
@@ -77,7 +77,7 @@ def redact(self, schema_id: str, data: dict, extra_args: dict) -> dict:
redactor = Redactor()

with pytest.raises(ModifierError):
event_logger.add_modifier(modifier=redactor.redact)
event_logger.add_modifier(modifier=redactor.redact) # type:ignore[arg-type]

# Ensure no modifier was added
assert len(event_logger._modifiers[schema.id]) == 0
@@ -90,7 +90,7 @@ def modifier_with_extra_args(event):
return event

with pytest.raises(ModifierError):
logger.add_modifier(modifier=modifier_with_extra_args)
logger.add_modifier(modifier=modifier_with_extra_args) # type:ignore[arg-type]


def test_remove_modifier(schema, jp_event_logger, jp_read_emitted_events):
2 changes: 1 addition & 1 deletion tests/test_schema.py
Original file line number Diff line number Diff line change
@@ -59,7 +59,7 @@ def test_string_intended_as_path():
def test_unrecognized_type():
"""Validation fails because file is not of valid type."""
with pytest.raises(EventSchemaUnrecognized):
EventSchema(9001)
EventSchema(9001) # type:ignore[arg-type]


def test_invalid_yaml():
4 changes: 2 additions & 2 deletions tests/utils.py
Original file line number Diff line number Diff line change
@@ -16,10 +16,10 @@ def get_event_data(event, schema, schema_id, version, unredacted_policies):
handler = logging.StreamHandler(sink)

e = EventLogger(handlers=[handler], unredacted_policies=unredacted_policies)
e.register_schema(schema)
e.register_event_schema(schema)

# Record event and read output
e.emit(schema_id, version, deepcopy(event))
e.emit(schema_id=schema_id, data=deepcopy(event))

recorded_event = json.loads(sink.getvalue())
return {key: value for key, value in recorded_event.items() if not key.startswith("__")}

0 comments on commit ac65980

Please sign in to comment.