diff --git a/lib/jwt.rb b/lib/jwt.rb index 47c8580b..5edc3824 100644 --- a/lib/jwt.rb +++ b/lib/jwt.rb @@ -1,6 +1,7 @@ # frozen_string_literal: true -require 'base64' +require 'jwt/base64' +require 'jwt/json' require 'jwt/decode' require 'jwt/default_options' require 'jwt/encode' @@ -17,12 +18,13 @@ module JWT module_function def encode(payload, key, algorithm = 'HS256', header_fields = {}) - encoder = Encode.new payload, key, algorithm, header_fields - encoder.segments + Encode.new(payload: payload, + key: key, + algorithm: algorithm, + headers: header_fields).segments end def decode(jwt, key = nil, verify = true, options = {}, &keyfinder) - decoder = Decode.new(jwt, key, verify, DEFAULT_OPTIONS.merge(options), &keyfinder) - decoder.decode_segments + Decode.new(jwt, key, verify, DEFAULT_OPTIONS.merge(options), &keyfinder).decode_segments end end diff --git a/lib/jwt/base64.rb b/lib/jwt/base64.rb new file mode 100644 index 00000000..e69808b1 --- /dev/null +++ b/lib/jwt/base64.rb @@ -0,0 +1,19 @@ +# frozen_string_literal: true + +require 'base64' + +module JWT + # Base64 helpers + class Base64 + class << self + def url_encode(str) + ::Base64.encode64(str).tr('+/', '-_').gsub(/[\n=]/, '') + end + + def url_decode(str) + str += '=' * (4 - str.length.modulo(4)) + ::Base64.decode64(str.tr('-_', '+/')) + end + end + end +end diff --git a/lib/jwt/decode.rb b/lib/jwt/decode.rb index c6a6f543..37ee2e0e 100644 --- a/lib/jwt/decode.rb +++ b/lib/jwt/decode.rb @@ -8,11 +8,6 @@ module JWT # Decoding logic for JWT class Decode - def self.base64url_decode(str) - str += '=' * (4 - str.length.modulo(4)) - Base64.decode64(str.tr('-_', '+/')) - end - def initialize(jwt, key, verify, options, &keyfinder) raise(JWT::DecodeError, 'Nil JSON web token') unless jwt @jwt = jwt @@ -80,7 +75,7 @@ def segment_length end def decode_crypto - @signature = Decode.base64url_decode(@segments[2]) + @signature = JWT::Base64.url_decode(@segments[2]) end def header @@ -96,7 +91,7 @@ def signing_input end def parse_and_decode(segment) - JSON.parse(Decode.base64url_decode(segment)) + JSON.parse(JWT::Base64.url_decode(segment)) rescue JSON::ParserError raise JWT::DecodeError, 'Invalid segment encoding' end diff --git a/lib/jwt/encode.rb b/lib/jwt/encode.rb index 7e8f07d4..8769e7f4 100644 --- a/lib/jwt/encode.rb +++ b/lib/jwt/encode.rb @@ -1,51 +1,76 @@ # frozen_string_literal: true -require 'json' - # JWT::Encode module module JWT # Encoding logic for JWT class Encode - attr_reader :payload, :key, :algorithm, :header_fields, :segments + ALG_NONE = 'none'.freeze + ALG_KEY = 'alg'.freeze + EXP_KEY = 'exp'.freeze + EXP_KEYS = [EXP_KEY, EXP_KEY.to_sym].freeze - def self.base64url_encode(str) - Base64.encode64(str).tr('+/', '-_').gsub(/[\n=]/, '') + def initialize(options) + @payload = options[:payload] + @key = options[:key] + @algorithm = options[:algorithm] + @headers = options[:headers] end - def initialize(payload, key, algorithm, header_fields) - @payload = payload - @key = key - @algorithm = algorithm - @header_fields = header_fields - @segments = encode_segments + def segments + @segments ||= combine(encoded_header_and_payload, encoded_signature) end private + def validate_payload! + return unless @payload && @payload.is_a?(Hash) + + validate_exp! + end + + def validate_exp! + return if EXP_KEYS.all? { |key| !@payload.key?(key) || @payload[key].is_a?(Integer) } + + raise InvalidPayload, 'exp claim must be an integer' + end + def encoded_header - header = { 'alg' => @algorithm }.merge(@header_fields) - Encode.base64url_encode(JSON.generate(header)) + @encoded_header ||= encode_header end def encoded_payload - raise InvalidPayload, 'exp claim must be an integer' if @payload && @payload.is_a?(Hash) && @payload.key?('exp') && !@payload['exp'].is_a?(Integer) - Encode.base64url_encode(JSON.generate(@payload)) + @encoded_payload ||= encode_payload + end + + def encoded_signature + @encoded_signature ||= encode_signature + end + + def encoded_header_and_payload + @encoded_header_and_payload ||= combine(encoded_header, encoded_payload) + end + + def encode_header + encode(@headers.merge(ALG_KEY => @algorithm)) + end + + def encode_payload + validate_payload! + encode(@payload) + end + + def encode_signature + return '' if @algorithm == ALG_NONE + + JWT::Base64.url_encode(JWT::Signature.sign(@algorithm, encoded_header_and_payload, @key)) end - def encoded_signature(signing_input) - if @algorithm == 'none' - '' - else - signature = JWT::Signature.sign(@algorithm, signing_input, @key) - Encode.base64url_encode(signature) - end + def encode(data) + JWT::Base64.url_encode(JWT::JSON.generate(data)) end - def encode_segments - header = encoded_header - payload = encoded_payload - signature = encoded_signature([header, payload].join('.')) - [header, payload, signature].join('.') + def combine(*parts) + parts.join('.') end end end diff --git a/lib/jwt/json.rb b/lib/jwt/json.rb new file mode 100644 index 00000000..dffda08f --- /dev/null +++ b/lib/jwt/json.rb @@ -0,0 +1,18 @@ +# frozen_string_literal: true + +require 'json' + +module JWT + # JSON wrapper + class JSON + class << self + def generate(data) + ::JSON.generate(data) + end + + def parse(data) + ::JSON.parse(data) + end + end + end +end diff --git a/lib/jwt/jwk/rsa.rb b/lib/jwt/jwk/rsa.rb index e4fe552a..4691de75 100644 --- a/lib/jwt/jwk/rsa.rb +++ b/lib/jwt/jwk/rsa.rb @@ -27,16 +27,16 @@ def kid def export { kty: KTY, - n: Base64.urlsafe_encode64(public_key.n.to_s(BINARY), padding: false), - e: Base64.urlsafe_encode64(public_key.e.to_s(BINARY), padding: false), + n: ::Base64.urlsafe_encode64(public_key.n.to_s(BINARY), padding: false), + e: ::Base64.urlsafe_encode64(public_key.e.to_s(BINARY), padding: false), kid: kid } end def self.import(jwk_data) imported_key = OpenSSL::PKey::RSA.new - imported_key.set_key(OpenSSL::BN.new(Base64.urlsafe_decode64(jwk_data[:n]), BINARY), - OpenSSL::BN.new(Base64.urlsafe_decode64(jwk_data[:e]), BINARY), + imported_key.set_key(OpenSSL::BN.new(::Base64.urlsafe_decode64(jwk_data[:n]), BINARY), + OpenSSL::BN.new(::Base64.urlsafe_decode64(jwk_data[:e]), BINARY), nil) self.new(imported_key) end diff --git a/spec/jwt_spec.rb b/spec/jwt_spec.rb index 65085aff..cccd3c20 100644 --- a/spec/jwt_spec.rb +++ b/spec/jwt_spec.rb @@ -57,21 +57,33 @@ expect(header['alg']).to eq alg expect(jwt_payload).to eq payload end + end - it 'should display a better error message if payload exp is_a?(Time)' do - payload['exp'] = Time.now + context 'payload validation' do + subject { JWT.encode(payload, nil, 'none') } + let(:payload) { { 'exp' => exp } } - expect do - JWT.encode payload, nil, alg - end.to raise_error JWT::InvalidPayload + context 'when exp is given as a non Integer' do + let(:exp) { Time.now.to_i.to_s } + it 'raises an JWT::InvalidPayload error' do + expect { subject }.to raise_error(JWT::InvalidPayload, 'exp claim must be an integer') + end end - it 'should display a better error message if payload exp is not an Integer' do - payload['exp'] = Time.now.to_i.to_s + context 'when exp is given as an Integer' do + let(:exp) { 1234 } - expect do - JWT.encode payload, nil, alg - end.to raise_error JWT::InvalidPayload + it 'encodes the payload' do + expect(subject).to be_a(String) + end + end + + context 'when the key for exp is a symbol' do + let(:payload) { { :exp => 'NotAInteger' } } + + it 'raises an JWT::InvalidPayload error' do + expect { subject }.to raise_error(JWT::InvalidPayload, 'exp claim must be an integer') + end end end @@ -228,7 +240,7 @@ translated_alg = alg.gsub('PS', 'sha') valid_signature = data[:rsa_public].verify_pss( translated_alg, - JWT::Decode.base64url_decode(signature), + JWT::Base64.url_decode(signature), [header, body].join('.'), salt_length: :auto, mgf1_hash: translated_alg @@ -339,7 +351,7 @@ context 'Base64' do it 'urlsafe replace + / with - _' do allow(Base64).to receive(:encode64) { 'string+with/non+url-safe/characters_' } - expect(JWT::Encode.base64url_encode('foo')).to eq('string-with_non-url-safe_characters_') + expect(JWT::Base64.url_encode('foo')).to eq('string-with_non-url-safe_characters_') end end @@ -364,4 +376,12 @@ JWT.encode 'Hello World', 'secret' end.not_to raise_error end + + context 'when the alg value is given as a header parameter' do + + it 'does not override the actual algorithm used' do + headers = JSON.parse(::JWT::Base64.url_decode(JWT.encode('Hello World', 'secret', 'HS256', { alg: 'HS123'}).split('.').first)) + expect(headers['alg']).to eq('HS256') + end + end end