From a65a3fe0b33115fe3e4177a11b1a4c29a679f934 Mon Sep 17 00:00:00 2001 From: shchva Date: Sat, 11 Jan 2025 00:42:49 +0300 Subject: [PATCH] create: pkg jwk --- jwk/jwk.go | 120 ++++++++++++++++++++++++++++++++++ jwk/jwk_test.go | 168 ++++++++++++++++++++++++++++++++++++++++++++++++ jwk/rsa.go | 117 +++++++++++++++++++++++++++++++++ 3 files changed, 405 insertions(+) create mode 100644 jwk/jwk.go create mode 100644 jwk/jwk_test.go create mode 100644 jwk/rsa.go diff --git a/jwk/jwk.go b/jwk/jwk.go new file mode 100644 index 0000000..3d3bdef --- /dev/null +++ b/jwk/jwk.go @@ -0,0 +1,120 @@ +package jwk + +import ( + "encoding/json" + "errors" + "io" + + "git.daebt.dev/auth/algo" + "git.daebt.dev/auth/algo/rs" +) + +type Token interface { + Algo() (algo.Algorithm, error) + KeyId() []byte +} + +type UseType string + +const ( + UseSignature UseType = "sig" + UseEncryption UseType = "enc" +) + +type Header struct { + KeyId string `json:"kid,omitempty"` // KeyId уникальный идентификатор ключа + KeyType algo.KeyType `json:"kty,omitempty"` // KeyType определяет криптографический алгоритм + Use UseType `json:"use,omitempty"` // Use определяет использование ключа + Algorithm algo.AlgorithmType `json:"alg,omitempty"` // Algorithm определяет алгоритм хеширования +} + +type List struct { + v []Token +} + +func (l *List) SelectByKid(kid []byte) Token { + val := string(kid) + + for _, v := range l.v { + if string(v.KeyId()) == val { + return v + } + } + + return nil +} + +func (l *List) Range(f func(t Token) bool) { + for _, v := range l.v { + if !f(v) { + return + } + } +} + +func (l *List) Write(w io.Writer) error { + return json.NewEncoder(w).Encode(&struct { + Keys []Token `json:"keys"` + }{l.v}) +} + +func (l *List) WriteBytes() ([]byte, error) { + return json.Marshal(&struct { + Keys []Token `json:"keys"` + }{l.v}) +} + +func NewList(t ...Token) *List { + return &List{ + v: t, + } +} + +// ParseList +func ParseList(buf []byte, parse func(json.RawMessage) Token) (*List, error) { + rows := &struct { + Keys []json.RawMessage `json:"keys"` + }{} + + if err := json.Unmarshal(buf, rows); err != nil { + return nil, err + } + + l := &List{ + v: []Token{}, + } + + for _, v := range rows.Keys { + if parse != nil { + if val := parse(v); val != nil { + l.v = append(l.v, val) + continue + } + } + + var ( + tkn Token + info = &struct { + KeyType algo.KeyType `json:"kty"` + }{} + ) + if err := json.Unmarshal(v, info); err != nil { + return l, err + } + + switch info.KeyType { + case rs.KeyRSA: + tkn = new(TokenRSA) + default: + return l, errors.New("undefined key type") + } + + if err := json.Unmarshal(v, tkn); err != nil { + return l, err + } + + l.v = append(l.v, tkn) + } + + return l, nil +} diff --git a/jwk/jwk_test.go b/jwk/jwk_test.go new file mode 100644 index 0000000..667cd13 --- /dev/null +++ b/jwk/jwk_test.go @@ -0,0 +1,168 @@ +package jwk_test + +import ( + "errors" + "fmt" + "testing" + + "git.daebt.dev/auth/algo" + "git.daebt.dev/auth/jwk" + "git.daebt.dev/auth/jwt" +) + +var rows = `{ + "keys": [ + { + "kid": "MDE5NDUxZjEtZjc4OS03MmE2LTgzNmQtM2JjNjE0NmFkNzZh", + "kty": "RSA", + "use": "sig", + "alg": "RS256", + "e": "AQAB", + "n": "zIWvl1OtwExnQ3HvoYk6bFRIlcCjzdv1yHazJfr6jxk6w-tCsIWEdtKsAPm3gmmFG-mTuHq-H53sahm6DD9YC5ZQjnvSYkBKv70Zw331_tg9VLbfJc-gN7kbD3xMQsucYD0973r7l9pEPH4Qw_I-BEKHMxlTmynStgKxnfyO6iPkL5jTzpUXlD4V9xqUoMY_uX3EpGwbJJKFJuphYX3jzJQ--tovQGGep7RgNeMEoWjyAkJ2yb7tiSWsw7qk6GO2z6NDmnc1UsdSBZ6Vg7BPUp8EINAdX1wbmB0-QH_vp0huM6lSY6NMBugwptQUe5UCaly9fN4kb26U0qglEoGB6w" + } + ] +}` + +func TestParse(t *testing.T) { + var data = []*struct { + t string + f func(t *jwt.Token, a algo.Algorithm) error + }{ + { + "eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUZ3pObVF0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzAuZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2dpdC5kYWVidC5kZXYiLCJzdWIiOiJleGFtcGxlOjAifQ.avJZ8_n7UpEPUIf1ZuwRwRxpAKRowdn-DrdlOTDtuI3A0elZzEdcO35fhkhe_gwpZlg-URQzxdFsUD6hgan0vMakkhYd8HgocD_iGK70blhay96v-PpATHvSQgi-9MSQHbbInLpHSftv6DzzZpIhNDGEW8agcmpyqJ9LHYF95rvaAGpNeuckTXOJrzcDsfQa8h2Vabyy9QaQ0elB-CYNyQg82K25yGnJ441q3duWX6XEgMKb58BJ63FteYIeh0GCX-GRK7Qj7cC4DaL3bJwjXLDfuqw6C5WJ43Xmbg4KnNKiub1SaLaE5LaKsZH_svMvH_ki8DKJxmJejb63xhdWzQ", + func(t *jwt.Token, a algo.Algorithm) error { + return t.Verify(a) + }, + }, + { + "eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUZ3pObVF0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzEuZXhhbXBsZS5jb20iLCJleHAiOiIyMDI1LTAxLTExVDAxOjE5OjM3LjAxMjE1NjQxOCswMzowMCIsImlhdCI6IjIwMjUtMDEtMTFUMDA6MTk6MzcuMDEyMTU2NDE4KzAzOjAwIiwiaXNzIjoiaHR0cHM6Ly9naXQuZGFlYnQuZGV2Iiwic3ViIjoiZXhhbXBsZToxIn0.YL8rblFD-kelG_UYoEkaa7B7b0Em520croAc4PicKc-HmyE4n2pKB1JBL105No7XTFkCN4I9ddtZ5O0ffPAiGtWR3Mhorl8o4FBBnWFSQJdYWuu5UZxYl2G2_AK6gE9HdIeujZiim9aHfhOmdhWJ-pJUiGAQjOvNHOYTPgvBug2sPvnpbkbbQc0ZsvHh2AtQ6BoTDJpLiSY7q7sr7AiGaJUL_4xcUSh5cd7Zief62FFjOBDWe1HWfGDr0JCrumrEWkT11lCXrftYaucRJ9gnD-qTh9Mx9iLCpk3YOuwMnw1q1ZgZ5SGJ2iwGoQnw40oYwd9aLj_OpWamT8jVcpAeQ2g", + func(t *jwt.Token, a algo.Algorithm) error { + err := t.Verify(a) + if errors.Is(err, algo.ErrInvalidSignature) { + return nil + } + + return err + }, + }, + { + "eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUSTVNVEV0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzIuZXhhbXBsZS5jb20iLCJpYXQiOiIyMDI1LTAxLTExVDAwOjE5OjM3LjAxMjE1NjQxOCswMzowMCIsImlzcyI6Imh0dHBzOi8vZ2l0LmRhZWJ0LmRldiIsIm5iZiI6IjIwMjUtMDEtMTFUMDA6MTk6MzcuMDEyMTU2NDE4KzAzOjAwIiwic3ViIjoiZXhhbXBsZToyIn0.Gg16e-pYVG0yokUTLf6M5-s2eirMvnIm_k-ebFNonlW9I5aSqc0PEI9e-Lp2eiagDqw0OSg8gt-V3xyhbU178W63VpxlqaOvxfjLWiW2Ql1qLlZUWDzEEfHVAdbyPUGYM06uVMv51ZwbnzdXTt5hlPK3dzV04yM5ISSsVBf9XhsPVBeUO8c0SKlnE4AdGE6nfjQKKKZ6eSkYJz5nRVWe5dbZC-vDAVYv5zj7L5iI7195tHWSlEQoAy9pYhPE_jVmtggsWBrY-U9hg3elrq3h2OhOKSUXT7g8zSLxOvSleg24qXaKP58trfCIks2xC-60G-6vx2e_XUI8hSp1suhOXw", + func(t *jwt.Token, a algo.Algorithm) error { + if a == nil { + return nil + } + + return errors.New("a is not nil") + }, + }, + } + + lst, err := jwk.ParseList([]byte(rows), nil) + if err != nil { + t.Fatal(err.Error()) + } + + for i, v := range data { + t.Run(fmt.Sprint(i), func(t *testing.T) { + tkn, err := jwt.Parse(v.t) + if err != nil { + t.Fatal(err.Error()) + } + + kid, err := tkn.Header.GetKeyId() + if err != nil { + t.Fatal(err.Error()) + } + + var alg algo.Algorithm + key := lst.SelectByKid(kid) + if key != nil { + alg, err = key.Algo() + if err != nil { + t.Fatal(err.Error()) + } + } + + if err := v.f(tkn, alg); err != nil { + t.Fatal(err.Error()) + } + }) + } +} + +var compareTests = []struct { + a, b []byte + i int +}{ + {[]byte(""), []byte(""), 0}, + {[]byte("a"), []byte(""), 1}, + {[]byte(""), []byte("a"), -1}, + {[]byte("abc"), []byte("abc"), 0}, + {[]byte("abd"), []byte("abc"), 1}, + {[]byte("abc"), []byte("abd"), -1}, + {[]byte("ab"), []byte("abc"), -1}, + {[]byte("abc"), []byte("ab"), 1}, + {[]byte("x"), []byte("ab"), 1}, + {[]byte("ab"), []byte("x"), -1}, + {[]byte("x"), []byte("a"), 1}, + {[]byte("b"), []byte("x"), -1}, + // test runtime·memeq's chunked implementation + {[]byte("abcdefgh"), []byte("abcdefgh"), 0}, + {[]byte("abcdefghi"), []byte("abcdefghi"), 0}, + {[]byte("abcdefghi"), []byte("abcdefghj"), -1}, + {[]byte("abcdefghj"), []byte("abcdefghi"), 1}, + {[]byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), []byte("abcdefghi"), 1}, + {[]byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), []byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), 0}, + // nil tests + {nil, nil, 0}, + {[]byte(""), nil, 0}, + {nil, []byte(""), 0}, + {[]byte("a"), nil, 1}, + {nil, []byte("a"), -1}, +} + +func BenchmarkEq(b *testing.B) { + var data = []func(a, b []byte) bool{ + func(a, b []byte) bool { + return string(a) == string(b) + }, + func(a, b []byte) bool { + if len(a) != len(b) { + return false + } + + for i := 0; i < len(a); i++ { + if a[i] != b[i] { + return false + } + } + + return true + }, + func(a, b []byte) bool { + if len(a) != len(b) { + return false + } + + for i, v := range a { + if b[i] != v { + return false + } + } + + return true + }, + } + + for i, d := range data { + b.Run(fmt.Sprint(i), func(b *testing.B) { + for _, v := range compareTests { + if d(v.a, v.b) && v.i != 0 { + b.Fatal(string(v.a), " = ", string(v.b)) + } + } + + }) + } + +} diff --git a/jwk/rsa.go b/jwk/rsa.go new file mode 100644 index 0000000..802dc1e --- /dev/null +++ b/jwk/rsa.go @@ -0,0 +1,117 @@ +package jwk + +import ( + "crypto/rsa" + "encoding/base64" + "errors" + "math/big" + + "git.daebt.dev/auth/algo" + "git.daebt.dev/auth/algo/rs" +) + +type TokenRSA struct { + *Header + E string `json:"e"` + N string `json:"n"` +} + +func (t *TokenRSA) encodeE(v int) { + var ( + buf = []byte{} + skp = true + ) + + for i := 56; i > -1; i -= 8 { + b := byte(v >> i) + if skp && b == 0 { + continue + } + skp = false + buf = append(buf, b) + } + + t.E = base64.RawURLEncoding.EncodeToString(buf) +} + +func (t *TokenRSA) decodeE() (int, error) { + if t.E == "" { + return 0, errors.New("e is empty") + } + + buf, err := base64.RawURLEncoding.DecodeString(t.E) + if err != nil { + return 0, err + } + + var ( + num int + l = len(buf) - 1 + ) + for i, b := range buf { + i = (l - i) * 8 + num = num | int(b)< 0 { + t.Header.KeyId = base64.RawURLEncoding.EncodeToString(kid) + } + + t.encodeE(pub.E) + return t +}