Source code for josepy.jws

"""JSON Web Signature."""
import argparse
import base64
import sys
from typing import (
    Any,
    Dict,
    FrozenSet,
    List,
    Mapping,
    Optional,
    Tuple,
    Type,
    cast,
)

from OpenSSL import crypto

import josepy
from josepy import b64, errors, json_util, jwa
from josepy import jwk as jwk_mod
from josepy import util


[docs]class MediaType: """MediaType field encoder/decoder.""" PREFIX = "application/" """MIME Media Type and Content Type prefix."""
[docs] @classmethod def decode(cls, value: str) -> str: """Decoder.""" # 4.1.10 if "/" not in value: if ";" in value: raise errors.DeserializationError("Unexpected semi-colon") return cls.PREFIX + value return value
[docs] @classmethod def encode(cls, value: str) -> str: """Encoder.""" # 4.1.10 if ";" not in value: assert value.startswith(cls.PREFIX) return value[len(cls.PREFIX) :] return value
[docs]class Signature(json_util.JSONObjectWithFields): """JWS Signature. :ivar combined: Combined Header (protected and unprotected, :class:`Header`). :ivar unicode protected: JWS protected header (Jose Base-64 decoded). :ivar header: JWS Unprotected Header (:class:`Header`). :ivar str signature: The signature. """ header_cls = Header combined: Header __slots__ = ("combined",) protected: str = json_util.field("protected", omitempty=True, default="") header: Header = json_util.field( "header", omitempty=True, default=header_cls(), decoder=header_cls.from_json ) signature: bytes = json_util.field( "signature", decoder=json_util.decode_b64jose, encoder=json_util.encode_b64jose ) @protected.encoder # type: ignore def protected(value: str) -> str: # wrong type guess (Signature, not bytes) | pylint: disable=no-member return json_util.encode_b64jose(value.encode("utf-8")) @protected.decoder # type: ignore def protected(value: str) -> str: return json_util.decode_b64jose(value).decode("utf-8") def __init__(self, **kwargs: Any) -> None: if "combined" not in kwargs: kwargs = self._with_combined(kwargs) super().__init__(**kwargs) assert self.combined.alg is not None @classmethod def _with_combined(cls, kwargs: Any) -> Dict[str, Any]: assert "combined" not in kwargs header = kwargs.get("header", cls._fields["header"].default) protected = kwargs.get("protected", cls._fields["protected"].default) if protected: combined = header + cls.header_cls.json_loads(protected) else: combined = header kwargs["combined"] = combined return kwargs @classmethod def _msg(cls, protected: str, payload: bytes) -> bytes: return b64.b64encode(protected.encode("utf-8")) + b"." + b64.b64encode(payload)
[docs] def verify(self, payload: bytes, key: Optional[josepy.JWK] = None) -> bool: """Verify. :param bytes payload: Payload to verify. :param JWK key: Key used for verification. """ actual_key: josepy.JWK = self.combined.find_key() if key is None else key if not self.combined.alg: raise josepy.Error("Not signature algorithm defined.") return self.combined.alg.verify( key=actual_key.key, sig=self.signature, msg=self._msg(self.protected, payload) )
[docs] @classmethod def sign( cls, payload: bytes, key: josepy.JWK, alg: josepy.JWASignature, include_jwk: bool = True, protect: FrozenSet = frozenset(), **kwargs: Any, ) -> "Signature": """Sign. :param bytes payload: Payload to sign. :param JWK key: Key for signature. :param JWASignature alg: Signature algorithm to use to sign. :param bool include_jwk: If True, insert the JWK inside the signature headers. :param FrozenSet protect: List of headers to protect. """ assert isinstance(key, alg.kty) header_params = kwargs header_params["alg"] = alg if include_jwk: header_params["jwk"] = key.public_key() assert set(header_params).issubset(cls.header_cls._fields) assert protect.issubset(cls.header_cls._fields) protected_params = {} for header in protect: if header in header_params: protected_params[header] = header_params.pop(header) if protected_params: protected = cls.header_cls(**protected_params).json_dumps() else: protected = "" header = cls.header_cls(**header_params) signature = alg.sign(key.key, cls._msg(protected, payload)) return cls(protected=protected, header=header, signature=signature)
[docs] def fields_to_partial_json(self) -> Dict[str, Any]: fields = super().fields_to_partial_json() if not fields["header"].not_omitted(): del fields["header"] return fields
[docs] @classmethod def fields_from_json(cls, jobj: Mapping[str, Any]) -> Dict[str, Any]: fields = super().fields_from_json(jobj) fields_with_combined = cls._with_combined(fields) if "alg" not in fields_with_combined["combined"].not_omitted(): raise errors.DeserializationError("alg not present") return fields_with_combined
[docs]class JWS(json_util.JSONObjectWithFields): """JSON Web Signature. :ivar str payload: JWS Payload. :ivar str signature: JWS Signatures. """ __slots__ = ("payload", "signatures") payload: bytes signatures: List[Signature] signature_cls = Signature
[docs] def verify(self, key: Optional[josepy.JWK] = None) -> bool: """Verify.""" return all(sig.verify(self.payload, key) for sig in self.signatures)
[docs] @classmethod def sign(cls, payload: bytes, **kwargs: Any) -> "JWS": """Sign.""" return cls(payload=payload, signatures=(cls.signature_cls.sign(payload=payload, **kwargs),))
@property def signature(self) -> Signature: """Get a singleton signature. :rtype: :class:`JWS.signature_cls` """ assert len(self.signatures) == 1 return self.signatures[0]
[docs] def to_compact(self) -> bytes: """Compact serialization. :rtype: bytes """ assert len(self.signatures) == 1 assert "alg" not in self.signature.header.not_omitted() # ... it must be in protected return ( b64.b64encode(self.signature.protected.encode("utf-8")) + b"." + b64.b64encode(self.payload) + b"." + b64.b64encode(self.signature.signature) )
[docs] @classmethod def from_compact(cls, compact: bytes) -> "JWS": """Compact deserialization. :param bytes compact: """ try: protected, payload, signature = compact.split(b".") except ValueError: raise errors.DeserializationError( "Compact JWS serialization should comprise of exactly" " 3 dot-separated components" ) sig = cls.signature_cls( protected=b64.b64decode(protected).decode("utf-8"), signature=b64.b64decode(signature) ) return cls(payload=b64.b64decode(payload), signatures=(sig,))
[docs] def to_partial_json(self, flat: bool = True) -> Dict[str, Any]: assert self.signatures payload = json_util.encode_b64jose(self.payload) if flat and len(self.signatures) == 1: ret = self.signatures[0].to_partial_json() ret["payload"] = payload return ret else: return { "payload": payload, "signatures": self.signatures, }
[docs] @classmethod def from_json(cls, jobj: Mapping[str, Any]) -> "JWS": if "signature" in jobj and "signatures" in jobj: raise errors.DeserializationError("Flat mixed with non-flat") elif "signature" in jobj: # flat filtered = {key: value for key, value in jobj.items() if key != "payload"} return cls( payload=json_util.decode_b64jose(jobj["payload"]), signatures=(cls.signature_cls.from_json(filtered),), ) else: return cls( payload=json_util.decode_b64jose(jobj["payload"]), signatures=tuple(cls.signature_cls.from_json(sig) for sig in jobj["signatures"]), )
[docs]class CLI: """JWS CLI."""
[docs] @classmethod def sign(cls, args: argparse.Namespace) -> None: """Sign.""" key = args.alg.kty.load(args.key.read()) args.key.close() if args.protect is None: args.protect = [] if args.compact: args.protect.append("alg") sig = JWS.sign( payload=sys.stdin.read().encode(), key=key, alg=args.alg, protect=set(args.protect) ) if args.compact: print(sig.to_compact().decode("utf-8")) else: # JSON print(sig.json_dumps_pretty())
[docs] @classmethod def verify(cls, args: argparse.Namespace) -> bool: """Verify.""" if args.compact: sig = JWS.from_compact(sys.stdin.read().encode()) else: # JSON try: sig = cast(JWS, JWS.json_loads(sys.stdin.read())) except errors.Error as error: print(error) return False if args.key is not None: assert args.kty is not None key = args.kty.load(args.key.read()).public_key() args.key.close() else: key = None sys.stdout.write(sig.payload.decode()) return not sig.verify(key=key)
@classmethod def _alg_type(cls, arg: Any) -> jwa.JWASignature: return jwa.JWASignature.from_json(arg) @classmethod def _header_type(cls, arg: Any) -> Any: assert arg in Signature.header_cls._fields return arg @classmethod def _kty_type(cls, arg: Any) -> Type[jwk_mod.JWK]: assert arg in jwk_mod.JWK.TYPES return jwk_mod.JWK.TYPES[arg]
[docs] @classmethod def run(cls, args: Optional[List[str]] = None) -> Optional[bool]: """Parse arguments and sign/verify.""" if args is None: args = sys.argv[1:] parser = argparse.ArgumentParser() parser.add_argument("--compact", action="store_true") subparsers = parser.add_subparsers() parser_sign = subparsers.add_parser("sign") parser_sign.set_defaults(func=cls.sign) parser_sign.add_argument("-k", "--key", type=argparse.FileType("rb"), required=True) parser_sign.add_argument("-a", "--alg", type=cls._alg_type, default=jwa.RS256) parser_sign.add_argument("-p", "--protect", action="append", type=cls._header_type) parser_verify = subparsers.add_parser("verify") parser_verify.set_defaults(func=cls.verify) parser_verify.add_argument("-k", "--key", type=argparse.FileType("rb"), required=False) parser_verify.add_argument("--kty", type=cls._kty_type, required=False) parsed = parser.parse_args(args) return parsed.func(parsed)
if __name__ == "__main__": exit(CLI.run()) # pragma: no cover