create: pkg jwk
This commit is contained in:
parent
d3d2549829
commit
a65a3fe0b3
120
jwk/jwk.go
Normal file
120
jwk/jwk.go
Normal file
@ -0,0 +1,120 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
|
||||
"git.daebt.dev/auth/algo"
|
||||
"git.daebt.dev/auth/algo/rs"
|
||||
)
|
||||
|
||||
type Token interface {
|
||||
Algo() (algo.Algorithm, error)
|
||||
KeyId() []byte
|
||||
}
|
||||
|
||||
type UseType string
|
||||
|
||||
const (
|
||||
UseSignature UseType = "sig"
|
||||
UseEncryption UseType = "enc"
|
||||
)
|
||||
|
||||
type Header struct {
|
||||
KeyId string `json:"kid,omitempty"` // KeyId уникальный идентификатор ключа
|
||||
KeyType algo.KeyType `json:"kty,omitempty"` // KeyType определяет криптографический алгоритм
|
||||
Use UseType `json:"use,omitempty"` // Use определяет использование ключа
|
||||
Algorithm algo.AlgorithmType `json:"alg,omitempty"` // Algorithm определяет алгоритм хеширования
|
||||
}
|
||||
|
||||
type List struct {
|
||||
v []Token
|
||||
}
|
||||
|
||||
func (l *List) SelectByKid(kid []byte) Token {
|
||||
val := string(kid)
|
||||
|
||||
for _, v := range l.v {
|
||||
if string(v.KeyId()) == val {
|
||||
return v
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (l *List) Range(f func(t Token) bool) {
|
||||
for _, v := range l.v {
|
||||
if !f(v) {
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (l *List) Write(w io.Writer) error {
|
||||
return json.NewEncoder(w).Encode(&struct {
|
||||
Keys []Token `json:"keys"`
|
||||
}{l.v})
|
||||
}
|
||||
|
||||
func (l *List) WriteBytes() ([]byte, error) {
|
||||
return json.Marshal(&struct {
|
||||
Keys []Token `json:"keys"`
|
||||
}{l.v})
|
||||
}
|
||||
|
||||
func NewList(t ...Token) *List {
|
||||
return &List{
|
||||
v: t,
|
||||
}
|
||||
}
|
||||
|
||||
// ParseList
|
||||
func ParseList(buf []byte, parse func(json.RawMessage) Token) (*List, error) {
|
||||
rows := &struct {
|
||||
Keys []json.RawMessage `json:"keys"`
|
||||
}{}
|
||||
|
||||
if err := json.Unmarshal(buf, rows); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
l := &List{
|
||||
v: []Token{},
|
||||
}
|
||||
|
||||
for _, v := range rows.Keys {
|
||||
if parse != nil {
|
||||
if val := parse(v); val != nil {
|
||||
l.v = append(l.v, val)
|
||||
continue
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
tkn Token
|
||||
info = &struct {
|
||||
KeyType algo.KeyType `json:"kty"`
|
||||
}{}
|
||||
)
|
||||
if err := json.Unmarshal(v, info); err != nil {
|
||||
return l, err
|
||||
}
|
||||
|
||||
switch info.KeyType {
|
||||
case rs.KeyRSA:
|
||||
tkn = new(TokenRSA)
|
||||
default:
|
||||
return l, errors.New("undefined key type")
|
||||
}
|
||||
|
||||
if err := json.Unmarshal(v, tkn); err != nil {
|
||||
return l, err
|
||||
}
|
||||
|
||||
l.v = append(l.v, tkn)
|
||||
}
|
||||
|
||||
return l, nil
|
||||
}
|
||||
168
jwk/jwk_test.go
Normal file
168
jwk/jwk_test.go
Normal file
@ -0,0 +1,168 @@
|
||||
package jwk_test
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"testing"
|
||||
|
||||
"git.daebt.dev/auth/algo"
|
||||
"git.daebt.dev/auth/jwk"
|
||||
"git.daebt.dev/auth/jwt"
|
||||
)
|
||||
|
||||
var rows = `{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "MDE5NDUxZjEtZjc4OS03MmE2LTgzNmQtM2JjNjE0NmFkNzZh",
|
||||
"kty": "RSA",
|
||||
"use": "sig",
|
||||
"alg": "RS256",
|
||||
"e": "AQAB",
|
||||
"n": "zIWvl1OtwExnQ3HvoYk6bFRIlcCjzdv1yHazJfr6jxk6w-tCsIWEdtKsAPm3gmmFG-mTuHq-H53sahm6DD9YC5ZQjnvSYkBKv70Zw331_tg9VLbfJc-gN7kbD3xMQsucYD0973r7l9pEPH4Qw_I-BEKHMxlTmynStgKxnfyO6iPkL5jTzpUXlD4V9xqUoMY_uX3EpGwbJJKFJuphYX3jzJQ--tovQGGep7RgNeMEoWjyAkJ2yb7tiSWsw7qk6GO2z6NDmnc1UsdSBZ6Vg7BPUp8EINAdX1wbmB0-QH_vp0huM6lSY6NMBugwptQUe5UCaly9fN4kb26U0qglEoGB6w"
|
||||
}
|
||||
]
|
||||
}`
|
||||
|
||||
func TestParse(t *testing.T) {
|
||||
var data = []*struct {
|
||||
t string
|
||||
f func(t *jwt.Token, a algo.Algorithm) error
|
||||
}{
|
||||
{
|
||||
"eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUZ3pObVF0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzAuZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2dpdC5kYWVidC5kZXYiLCJzdWIiOiJleGFtcGxlOjAifQ.avJZ8_n7UpEPUIf1ZuwRwRxpAKRowdn-DrdlOTDtuI3A0elZzEdcO35fhkhe_gwpZlg-URQzxdFsUD6hgan0vMakkhYd8HgocD_iGK70blhay96v-PpATHvSQgi-9MSQHbbInLpHSftv6DzzZpIhNDGEW8agcmpyqJ9LHYF95rvaAGpNeuckTXOJrzcDsfQa8h2Vabyy9QaQ0elB-CYNyQg82K25yGnJ441q3duWX6XEgMKb58BJ63FteYIeh0GCX-GRK7Qj7cC4DaL3bJwjXLDfuqw6C5WJ43Xmbg4KnNKiub1SaLaE5LaKsZH_svMvH_ki8DKJxmJejb63xhdWzQ",
|
||||
func(t *jwt.Token, a algo.Algorithm) error {
|
||||
return t.Verify(a)
|
||||
},
|
||||
},
|
||||
{
|
||||
"eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUZ3pObVF0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzEuZXhhbXBsZS5jb20iLCJleHAiOiIyMDI1LTAxLTExVDAxOjE5OjM3LjAxMjE1NjQxOCswMzowMCIsImlhdCI6IjIwMjUtMDEtMTFUMDA6MTk6MzcuMDEyMTU2NDE4KzAzOjAwIiwiaXNzIjoiaHR0cHM6Ly9naXQuZGFlYnQuZGV2Iiwic3ViIjoiZXhhbXBsZToxIn0.YL8rblFD-kelG_UYoEkaa7B7b0Em520croAc4PicKc-HmyE4n2pKB1JBL105No7XTFkCN4I9ddtZ5O0ffPAiGtWR3Mhorl8o4FBBnWFSQJdYWuu5UZxYl2G2_AK6gE9HdIeujZiim9aHfhOmdhWJ-pJUiGAQjOvNHOYTPgvBug2sPvnpbkbbQc0ZsvHh2AtQ6BoTDJpLiSY7q7sr7AiGaJUL_4xcUSh5cd7Zief62FFjOBDWe1HWfGDr0JCrumrEWkT11lCXrftYaucRJ9gnD-qTh9Mx9iLCpk3YOuwMnw1q1ZgZ5SGJ2iwGoQnw40oYwd9aLj_OpWamT8jVcpAeQ2g",
|
||||
func(t *jwt.Token, a algo.Algorithm) error {
|
||||
err := t.Verify(a)
|
||||
if errors.Is(err, algo.ErrInvalidSignature) {
|
||||
return nil
|
||||
}
|
||||
|
||||
return err
|
||||
},
|
||||
},
|
||||
{
|
||||
"eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUSTVNVEV0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzIuZXhhbXBsZS5jb20iLCJpYXQiOiIyMDI1LTAxLTExVDAwOjE5OjM3LjAxMjE1NjQxOCswMzowMCIsImlzcyI6Imh0dHBzOi8vZ2l0LmRhZWJ0LmRldiIsIm5iZiI6IjIwMjUtMDEtMTFUMDA6MTk6MzcuMDEyMTU2NDE4KzAzOjAwIiwic3ViIjoiZXhhbXBsZToyIn0.Gg16e-pYVG0yokUTLf6M5-s2eirMvnIm_k-ebFNonlW9I5aSqc0PEI9e-Lp2eiagDqw0OSg8gt-V3xyhbU178W63VpxlqaOvxfjLWiW2Ql1qLlZUWDzEEfHVAdbyPUGYM06uVMv51ZwbnzdXTt5hlPK3dzV04yM5ISSsVBf9XhsPVBeUO8c0SKlnE4AdGE6nfjQKKKZ6eSkYJz5nRVWe5dbZC-vDAVYv5zj7L5iI7195tHWSlEQoAy9pYhPE_jVmtggsWBrY-U9hg3elrq3h2OhOKSUXT7g8zSLxOvSleg24qXaKP58trfCIks2xC-60G-6vx2e_XUI8hSp1suhOXw",
|
||||
func(t *jwt.Token, a algo.Algorithm) error {
|
||||
if a == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
return errors.New("a is not nil")
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
lst, err := jwk.ParseList([]byte(rows), nil)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
for i, v := range data {
|
||||
t.Run(fmt.Sprint(i), func(t *testing.T) {
|
||||
tkn, err := jwt.Parse(v.t)
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
kid, err := tkn.Header.GetKeyId()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
|
||||
var alg algo.Algorithm
|
||||
key := lst.SelectByKid(kid)
|
||||
if key != nil {
|
||||
alg, err = key.Algo()
|
||||
if err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
}
|
||||
|
||||
if err := v.f(tkn, alg); err != nil {
|
||||
t.Fatal(err.Error())
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
var compareTests = []struct {
|
||||
a, b []byte
|
||||
i int
|
||||
}{
|
||||
{[]byte(""), []byte(""), 0},
|
||||
{[]byte("a"), []byte(""), 1},
|
||||
{[]byte(""), []byte("a"), -1},
|
||||
{[]byte("abc"), []byte("abc"), 0},
|
||||
{[]byte("abd"), []byte("abc"), 1},
|
||||
{[]byte("abc"), []byte("abd"), -1},
|
||||
{[]byte("ab"), []byte("abc"), -1},
|
||||
{[]byte("abc"), []byte("ab"), 1},
|
||||
{[]byte("x"), []byte("ab"), 1},
|
||||
{[]byte("ab"), []byte("x"), -1},
|
||||
{[]byte("x"), []byte("a"), 1},
|
||||
{[]byte("b"), []byte("x"), -1},
|
||||
// test runtime·memeq's chunked implementation
|
||||
{[]byte("abcdefgh"), []byte("abcdefgh"), 0},
|
||||
{[]byte("abcdefghi"), []byte("abcdefghi"), 0},
|
||||
{[]byte("abcdefghi"), []byte("abcdefghj"), -1},
|
||||
{[]byte("abcdefghj"), []byte("abcdefghi"), 1},
|
||||
{[]byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), []byte("abcdefghi"), 1},
|
||||
{[]byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), []byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), 0},
|
||||
// nil tests
|
||||
{nil, nil, 0},
|
||||
{[]byte(""), nil, 0},
|
||||
{nil, []byte(""), 0},
|
||||
{[]byte("a"), nil, 1},
|
||||
{nil, []byte("a"), -1},
|
||||
}
|
||||
|
||||
func BenchmarkEq(b *testing.B) {
|
||||
var data = []func(a, b []byte) bool{
|
||||
func(a, b []byte) bool {
|
||||
return string(a) == string(b)
|
||||
},
|
||||
func(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i := 0; i < len(a); i++ {
|
||||
if a[i] != b[i] {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
},
|
||||
func(a, b []byte) bool {
|
||||
if len(a) != len(b) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, v := range a {
|
||||
if b[i] != v {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
},
|
||||
}
|
||||
|
||||
for i, d := range data {
|
||||
b.Run(fmt.Sprint(i), func(b *testing.B) {
|
||||
for _, v := range compareTests {
|
||||
if d(v.a, v.b) && v.i != 0 {
|
||||
b.Fatal(string(v.a), " = ", string(v.b))
|
||||
}
|
||||
}
|
||||
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
117
jwk/rsa.go
Normal file
117
jwk/rsa.go
Normal file
@ -0,0 +1,117 @@
|
||||
package jwk
|
||||
|
||||
import (
|
||||
"crypto/rsa"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"math/big"
|
||||
|
||||
"git.daebt.dev/auth/algo"
|
||||
"git.daebt.dev/auth/algo/rs"
|
||||
)
|
||||
|
||||
type TokenRSA struct {
|
||||
*Header
|
||||
E string `json:"e"`
|
||||
N string `json:"n"`
|
||||
}
|
||||
|
||||
func (t *TokenRSA) encodeE(v int) {
|
||||
var (
|
||||
buf = []byte{}
|
||||
skp = true
|
||||
)
|
||||
|
||||
for i := 56; i > -1; i -= 8 {
|
||||
b := byte(v >> i)
|
||||
if skp && b == 0 {
|
||||
continue
|
||||
}
|
||||
skp = false
|
||||
buf = append(buf, b)
|
||||
}
|
||||
|
||||
t.E = base64.RawURLEncoding.EncodeToString(buf)
|
||||
}
|
||||
|
||||
func (t *TokenRSA) decodeE() (int, error) {
|
||||
if t.E == "" {
|
||||
return 0, errors.New("e is empty")
|
||||
}
|
||||
|
||||
buf, err := base64.RawURLEncoding.DecodeString(t.E)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
var (
|
||||
num int
|
||||
l = len(buf) - 1
|
||||
)
|
||||
for i, b := range buf {
|
||||
i = (l - i) * 8
|
||||
num = num | int(b)<<i
|
||||
i += 8
|
||||
}
|
||||
|
||||
return num, nil
|
||||
}
|
||||
|
||||
func (t *TokenRSA) Algo() (algo.Algorithm, error) {
|
||||
n, err := base64.RawURLEncoding.DecodeString(t.N)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
e, err := t.decodeE()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
o := rs.WithPub(&rsa.PublicKey{
|
||||
N: new(big.Int).SetBytes(n),
|
||||
E: e,
|
||||
})
|
||||
|
||||
switch t.Header.Algorithm {
|
||||
case rs.AlgorithmRS256:
|
||||
return rs.NewRS256(o)
|
||||
case rs.AlgorithmRS384:
|
||||
return rs.NewRS384(o)
|
||||
case rs.AlgorithmRS512:
|
||||
return rs.NewRS512(o)
|
||||
}
|
||||
|
||||
return nil, errors.New("undefined algorithm type")
|
||||
}
|
||||
|
||||
func (t *TokenRSA) KeyId() []byte {
|
||||
if t.Header.KeyId != "" {
|
||||
if val, err := base64.RawURLEncoding.DecodeString(t.Header.KeyId); err == nil {
|
||||
return val
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewRSA(kid []byte, v *rs.Algo) Token {
|
||||
pub := v.PublicKey()
|
||||
|
||||
t := &TokenRSA{
|
||||
Header: &Header{
|
||||
KeyId: "",
|
||||
KeyType: v.Key(),
|
||||
Use: UseSignature,
|
||||
Algorithm: v.Algo(),
|
||||
},
|
||||
E: "",
|
||||
N: base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
|
||||
}
|
||||
if len(kid) > 0 {
|
||||
t.Header.KeyId = base64.RawURLEncoding.EncodeToString(kid)
|
||||
}
|
||||
|
||||
t.encodeE(pub.E)
|
||||
return t
|
||||
}
|
||||
Loading…
x
Reference in New Issue
Block a user