"""JSON Web Signature."""
import argparse
import base64
import sys
from typing import (
Any,
Dict,
FrozenSet,
List,
Mapping,
Optional,
Tuple,
Type,
cast,
)
from OpenSSL import crypto
import josepy
from josepy import b64, errors, json_util, jwa
from josepy import jwk as jwk_mod
from josepy import util
class Signature(json_util.JSONObjectWithFields):
    """JWS Signature.

    :ivar combined: Combined Header (protected and unprotected,
        :class:`Header`).
    :ivar str protected: JWS protected header (Jose Base-64 decoded).
    :ivar header: JWS Unprotected Header (:class:`Header`).
    :ivar bytes signature: The signature.
    """

    header_cls = Header

    # ``combined`` is not a serialized field; it is derived from
    # ``header`` + ``protected`` in ``_with_combined``.
    combined: Header
    __slots__ = ("combined",)
    # NOTE: the order of the field definitions and of the encoder/decoder
    # re-bindings below is significant (each decorator rebinds ``protected``).
    protected: str = json_util.field("protected", omitempty=True, default="")
    header: Header = json_util.field(
        "header", omitempty=True, default=header_cls(), decoder=header_cls.from_json
    )
    signature: bytes = json_util.field(
        "signature", decoder=json_util.decode_b64jose, encoder=json_util.encode_b64jose
    )

    @protected.encoder  # type: ignore
    def protected(value: str) -> str:
        # wrong type guess (Signature, not bytes) | pylint: disable=no-member
        return json_util.encode_b64jose(value.encode("utf-8"))

    @protected.decoder  # type: ignore
    def protected(value: str) -> str:
        return json_util.decode_b64jose(value).decode("utf-8")

    def __init__(self, **kwargs: Any) -> None:
        # Derive ``combined`` from ``header``/``protected`` unless the caller
        # (e.g. ``fields_from_json``) already supplied it.
        if "combined" not in kwargs:
            kwargs = self._with_combined(kwargs)
        super().__init__(**kwargs)
        assert self.combined.alg is not None

    @classmethod
    def _with_combined(cls, kwargs: Any) -> Dict[str, Any]:
        """Add a ``combined`` entry merging the unprotected and protected headers.

        :param kwargs: Field values; must not already contain ``combined``.
            This mapping is updated in place and returned.
        """
        assert "combined" not in kwargs
        header = kwargs.get("header", cls._fields["header"].default)
        protected = kwargs.get("protected", cls._fields["protected"].default)

        if protected:
            combined = header + cls.header_cls.json_loads(protected)
        else:
            combined = header

        kwargs["combined"] = combined
        return kwargs

    @classmethod
    def _msg(cls, protected: str, payload: bytes) -> bytes:
        """Build the signing input: base64(protected) + b"." + base64(payload)."""
        return b64.b64encode(protected.encode("utf-8")) + b"." + b64.b64encode(payload)

    def verify(self, payload: bytes, key: Optional[josepy.JWK] = None) -> bool:
        """Verify.

        :param bytes payload: Payload to verify.
        :param JWK key: Key used for verification. When ``None``, the key is
            obtained from the combined header via ``find_key()``.
        :returns: The result of the signature algorithm's ``verify`` call.
        :raises josepy.Error: if no signature algorithm is defined.
        """
        # Check the algorithm first so that a missing ``alg`` is reported as
        # such instead of being masked by a failure to locate the key.
        if not self.combined.alg:
            raise josepy.Error("No signature algorithm defined.")
        actual_key: josepy.JWK = self.combined.find_key() if key is None else key
        return self.combined.alg.verify(
            key=actual_key.key, sig=self.signature, msg=self._msg(self.protected, payload)
        )

    @classmethod
    def sign(
        cls,
        payload: bytes,
        key: josepy.JWK,
        alg: josepy.JWASignature,
        include_jwk: bool = True,
        protect: FrozenSet = frozenset(),
        **kwargs: Any,
    ) -> "Signature":
        """Sign.

        :param bytes payload: Payload to sign.
        :param JWK key: Key for signature.
        :param JWASignature alg: Signature algorithm to use to sign.
        :param bool include_jwk: If True, insert the JWK inside the signature headers.
        :param FrozenSet protect: Names of header fields to move into the
            protected header.
        :param kwargs: Additional header fields.
        """
        # Programmer-contract checks: the key must match the algorithm's key
        # type, and every header name must be a known header field.
        assert isinstance(key, alg.kty)

        header_params = kwargs
        header_params["alg"] = alg
        if include_jwk:
            header_params["jwk"] = key.public_key()

        assert set(header_params).issubset(cls.header_cls._fields)
        assert protect.issubset(cls.header_cls._fields)

        # Move the requested fields out of the unprotected header parameters.
        protected_params = {}
        for header in protect:
            if header in header_params:
                protected_params[header] = header_params.pop(header)
        if protected_params:
            protected = cls.header_cls(**protected_params).json_dumps()
        else:
            protected = ""

        header = cls.header_cls(**header_params)
        signature = alg.sign(key.key, cls._msg(protected, payload))

        return cls(protected=protected, header=header, signature=signature)

    def fields_to_partial_json(self) -> Dict[str, Any]:
        """Serialize fields, dropping ``header`` when it has no non-omitted values."""
        fields = super().fields_to_partial_json()
        if not fields["header"].not_omitted():
            del fields["header"]
        return fields

    @classmethod
    def fields_from_json(cls, jobj: Mapping[str, Any]) -> Dict[str, Any]:
        """Deserialize fields and require ``alg`` in the combined header.

        :raises errors.DeserializationError: if ``alg`` is not present.
        """
        fields = super().fields_from_json(jobj)
        fields_with_combined = cls._with_combined(fields)
        if "alg" not in fields_with_combined["combined"].not_omitted():
            raise errors.DeserializationError("alg not present")
        return fields_with_combined
class JWS(json_util.JSONObjectWithFields):
    """JSON Web Signature.

    :ivar bytes payload: JWS Payload.
    :ivar signatures: JWS Signatures (a sequence of :class:`Signature`).
    """

    __slots__ = ("payload", "signatures")

    payload: bytes
    signatures: List[Signature]

    signature_cls = Signature

    def verify(self, key: Optional[josepy.JWK] = None) -> bool:
        """Verify every signature over the payload.

        :param JWK key: Key used for verification; when ``None`` each
            signature looks up its own key from its combined header.
        :returns: ``True`` only if there is at least one signature and all
            of them verify.
        """
        # ``all()`` over an empty sequence is True; a JWS carrying no
        # signatures must never be reported as verified.
        if not self.signatures:
            return False
        return all(sig.verify(self.payload, key) for sig in self.signatures)

    @classmethod
    def sign(cls, payload: bytes, **kwargs: Any) -> "JWS":
        """Create a JWS with a single signature over ``payload``.

        :param bytes payload: Payload to sign.
        :param kwargs: Forwarded to ``signature_cls.sign`` (``key``, ``alg``, ...).
        """
        return cls(payload=payload, signatures=(cls.signature_cls.sign(payload=payload, **kwargs),))

    @property
    def signature(self) -> Signature:
        """Get a singleton signature.

        :rtype: :class:`JWS.signature_cls`
        """
        assert len(self.signatures) == 1
        return self.signatures[0]

    def to_compact(self) -> bytes:
        """Compact serialization.

        Requires exactly one signature whose ``alg`` is carried in the
        protected header, since the compact form emits no unprotected header.

        :rtype: bytes
        """
        assert len(self.signatures) == 1

        assert "alg" not in self.signature.header.not_omitted()
        # ... it must be in protected

        return (
            b64.b64encode(self.signature.protected.encode("utf-8"))
            + b"."
            + b64.b64encode(self.payload)
            + b"."
            + b64.b64encode(self.signature.signature)
        )

    @classmethod
    def from_compact(cls, compact: bytes) -> "JWS":
        """Compact deserialization.

        :param bytes compact: ``protected.payload.signature`` in Base64url.
        :raises errors.DeserializationError: if there are not exactly three
            dot-separated components.
        """
        try:
            protected, payload, signature = compact.split(b".")
        except ValueError as error:
            raise errors.DeserializationError(
                "Compact JWS serialization should comprise of exactly"
                " 3 dot-separated components"
            ) from error

        sig = cls.signature_cls(
            protected=b64.b64decode(protected).decode("utf-8"), signature=b64.b64decode(signature)
        )
        return cls(payload=b64.b64decode(payload), signatures=(sig,))

    def to_partial_json(self, flat: bool = True) -> Dict[str, Any]:
        """Serialize to the flattened (single signature) or general JSON form.

        :param bool flat: Use the flattened form when there is exactly one
            signature; otherwise the ``signatures`` list form is used.
        """
        assert self.signatures
        payload = json_util.encode_b64jose(self.payload)

        if flat and len(self.signatures) == 1:
            ret = self.signatures[0].to_partial_json()
            ret["payload"] = payload
            return ret
        else:
            return {
                "payload": payload,
                "signatures": self.signatures,
            }

    @classmethod
    def from_json(cls, jobj: Mapping[str, Any]) -> "JWS":
        """Deserialize from the flattened or general JSON form.

        :raises errors.DeserializationError: if the flat and general forms are
            mixed, or if ``payload`` or the signature member(s) are missing.
        """
        if "signature" in jobj and "signatures" in jobj:
            raise errors.DeserializationError("Flat mixed with non-flat")
        # Report missing members as DeserializationError (an errors.Error)
        # rather than leaking a bare KeyError to callers.
        if "payload" not in jobj:
            raise errors.DeserializationError("payload not present")
        payload = json_util.decode_b64jose(jobj["payload"])

        if "signature" in jobj:  # flat
            filtered = {key: value for key, value in jobj.items() if key != "payload"}
            signatures = (cls.signature_cls.from_json(filtered),)
        elif "signatures" in jobj:
            signatures = tuple(cls.signature_cls.from_json(sig) for sig in jobj["signatures"])
        else:
            raise errors.DeserializationError("signatures not present")

        return cls(payload=payload, signatures=signatures)
class CLI:
    """JWS CLI.

    Reads the payload (``sign``) or the serialized JWS (``verify``) from
    stdin and writes the result to stdout. ``--compact`` selects compact
    serialization instead of JSON.
    """

    @classmethod
    def sign(cls, args: argparse.Namespace) -> None:
        """Sign the payload read from stdin and print the resulting JWS.

        :param argparse.Namespace args: Parsed arguments; uses ``key`` (an
            open binary file), ``alg``, ``protect`` and ``compact``.
        """
        # ``with`` closes the key file even if loading the key raises.
        with args.key as key_file:
            key = args.alg.kty.load(key_file.read())

        protect = set(args.protect or ())
        if args.compact:
            # The compact form carries no unprotected header, so ``alg`` must
            # be protected.
            protect.add("alg")

        sig = JWS.sign(payload=sys.stdin.read().encode(), key=key, alg=args.alg, protect=protect)

        if args.compact:
            print(sig.to_compact().decode("utf-8"))
        else:  # JSON
            print(sig.json_dumps_pretty())

    @classmethod
    def verify(cls, args: argparse.Namespace) -> bool:
        """Verify the JWS read from stdin and write its payload to stdout.

        The return value is meant to be used as the process exit status:
        ``False`` (exit code 0) when the signature verifies, ``True`` (exit
        code 1) when the input cannot be deserialized or verification fails.
        """
        if args.compact:
            sig = JWS.from_compact(sys.stdin.read().encode())
        else:  # JSON
            try:
                sig = cast(JWS, JWS.json_loads(sys.stdin.read()))
            except errors.Error as error:
                print(error)
                # Unparseable input is a failure, not a success (exit 0).
                return True

        if args.key is not None:
            assert args.kty is not None
            with args.key as key_file:
                key = args.kty.load(key_file.read()).public_key()
        else:
            key = None

        sys.stdout.write(sig.payload.decode())
        return not sig.verify(key=key)

    @classmethod
    def _alg_type(cls, arg: Any) -> jwa.JWASignature:
        """argparse ``type=`` converter: algorithm name -> JWASignature."""
        try:
            return jwa.JWASignature.from_json(arg)
        except (KeyError, errors.Error):
            # ArgumentTypeError becomes a clean usage error instead of a traceback.
            raise argparse.ArgumentTypeError(f"unknown signature algorithm: {arg}") from None

    @classmethod
    def _header_type(cls, arg: Any) -> Any:
        """argparse ``type=`` converter: validate a header field name."""
        if arg not in Signature.header_cls._fields:
            raise argparse.ArgumentTypeError(f"unknown header field: {arg}")
        return arg

    @classmethod
    def _kty_type(cls, arg: Any) -> Type[jwk_mod.JWK]:
        """argparse ``type=`` converter: key type name -> JWK subclass."""
        try:
            return jwk_mod.JWK.TYPES[arg]
        except KeyError:
            raise argparse.ArgumentTypeError(f"unknown key type: {arg}") from None

    @classmethod
    def run(cls, args: Optional[List[str]] = None) -> Optional[bool]:
        """Parse arguments and sign/verify.

        :param args: Command-line arguments; defaults to ``sys.argv[1:]``.
        :returns: The sub-command's result, suitable for ``sys.exit``.
        """
        if args is None:
            args = sys.argv[1:]
        parser = argparse.ArgumentParser()
        parser.add_argument("--compact", action="store_true")

        # ``dest`` records which sub-command, if any, was selected.
        subparsers = parser.add_subparsers(dest="command")

        parser_sign = subparsers.add_parser("sign")
        parser_sign.set_defaults(func=cls.sign)
        parser_sign.add_argument("-k", "--key", type=argparse.FileType("rb"), required=True)
        parser_sign.add_argument("-a", "--alg", type=cls._alg_type, default=jwa.RS256)
        parser_sign.add_argument("-p", "--protect", action="append", type=cls._header_type)

        parser_verify = subparsers.add_parser("verify")
        parser_verify.set_defaults(func=cls.verify)
        parser_verify.add_argument("-k", "--key", type=argparse.FileType("rb"), required=False)
        parser_verify.add_argument("--kty", type=cls._kty_type, required=False)

        parsed = parser.parse_args(args)
        # Without a sub-command there is no ``func``; report a usage error
        # (parser.error exits with status 2) instead of an AttributeError.
        if parsed.command is None:
            parser.error("a command is required: 'sign' or 'verify'")
        if parsed.command == "verify" and parsed.key is not None and parsed.kty is None:
            parser.error("--kty is required when --key is given")
        return parsed.func(parsed)
if __name__ == "__main__":
    # ``exit`` is injected by the ``site`` module and may be missing (e.g.
    # ``python -S`` or frozen executables); ``sys.exit`` is always available.
    sys.exit(CLI.run())  # pragma: no cover