create: pkg jwk

This commit is contained in:
shchva 2025-01-11 00:42:49 +03:00
parent d3d2549829
commit a65a3fe0b3
3 changed files with 405 additions and 0 deletions

120
jwk/jwk.go Normal file
View File

@ -0,0 +1,120 @@
package jwk
import (
"encoding/json"
"errors"
"io"
"git.daebt.dev/auth/algo"
"git.daebt.dev/auth/algo/rs"
)
type Token interface {
Algo() (algo.Algorithm, error)
KeyId() []byte
}
type UseType string
const (
UseSignature UseType = "sig"
UseEncryption UseType = "enc"
)
type Header struct {
KeyId string `json:"kid,omitempty"` // KeyId уникальный идентификатор ключа
KeyType algo.KeyType `json:"kty,omitempty"` // KeyType определяет криптографический алгоритм
Use UseType `json:"use,omitempty"` // Use определяет использование ключа
Algorithm algo.AlgorithmType `json:"alg,omitempty"` // Algorithm определяет алгоритм хеширования
}
type List struct {
v []Token
}
func (l *List) SelectByKid(kid []byte) Token {
val := string(kid)
for _, v := range l.v {
if string(v.KeyId()) == val {
return v
}
}
return nil
}
func (l *List) Range(f func(t Token) bool) {
for _, v := range l.v {
if !f(v) {
return
}
}
}
func (l *List) Write(w io.Writer) error {
return json.NewEncoder(w).Encode(&struct {
Keys []Token `json:"keys"`
}{l.v})
}
func (l *List) WriteBytes() ([]byte, error) {
return json.Marshal(&struct {
Keys []Token `json:"keys"`
}{l.v})
}
func NewList(t ...Token) *List {
return &List{
v: t,
}
}
// ParseList
func ParseList(buf []byte, parse func(json.RawMessage) Token) (*List, error) {
rows := &struct {
Keys []json.RawMessage `json:"keys"`
}{}
if err := json.Unmarshal(buf, rows); err != nil {
return nil, err
}
l := &List{
v: []Token{},
}
for _, v := range rows.Keys {
if parse != nil {
if val := parse(v); val != nil {
l.v = append(l.v, val)
continue
}
}
var (
tkn Token
info = &struct {
KeyType algo.KeyType `json:"kty"`
}{}
)
if err := json.Unmarshal(v, info); err != nil {
return l, err
}
switch info.KeyType {
case rs.KeyRSA:
tkn = new(TokenRSA)
default:
return l, errors.New("undefined key type")
}
if err := json.Unmarshal(v, tkn); err != nil {
return l, err
}
l.v = append(l.v, tkn)
}
return l, nil
}

168
jwk/jwk_test.go Normal file
View File

@ -0,0 +1,168 @@
package jwk_test
import (
"errors"
"fmt"
"testing"
"git.daebt.dev/auth/algo"
"git.daebt.dev/auth/jwk"
"git.daebt.dev/auth/jwt"
)
var rows = `{
"keys": [
{
"kid": "MDE5NDUxZjEtZjc4OS03MmE2LTgzNmQtM2JjNjE0NmFkNzZh",
"kty": "RSA",
"use": "sig",
"alg": "RS256",
"e": "AQAB",
"n": "zIWvl1OtwExnQ3HvoYk6bFRIlcCjzdv1yHazJfr6jxk6w-tCsIWEdtKsAPm3gmmFG-mTuHq-H53sahm6DD9YC5ZQjnvSYkBKv70Zw331_tg9VLbfJc-gN7kbD3xMQsucYD0973r7l9pEPH4Qw_I-BEKHMxlTmynStgKxnfyO6iPkL5jTzpUXlD4V9xqUoMY_uX3EpGwbJJKFJuphYX3jzJQ--tovQGGep7RgNeMEoWjyAkJ2yb7tiSWsw7qk6GO2z6NDmnc1UsdSBZ6Vg7BPUp8EINAdX1wbmB0-QH_vp0huM6lSY6NMBugwptQUe5UCaly9fN4kb26U0qglEoGB6w"
}
]
}`
func TestParse(t *testing.T) {
var data = []*struct {
t string
f func(t *jwt.Token, a algo.Algorithm) error
}{
{
"eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUZ3pObVF0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzAuZXhhbXBsZS5jb20iLCJpc3MiOiJodHRwczovL2dpdC5kYWVidC5kZXYiLCJzdWIiOiJleGFtcGxlOjAifQ.avJZ8_n7UpEPUIf1ZuwRwRxpAKRowdn-DrdlOTDtuI3A0elZzEdcO35fhkhe_gwpZlg-URQzxdFsUD6hgan0vMakkhYd8HgocD_iGK70blhay96v-PpATHvSQgi-9MSQHbbInLpHSftv6DzzZpIhNDGEW8agcmpyqJ9LHYF95rvaAGpNeuckTXOJrzcDsfQa8h2Vabyy9QaQ0elB-CYNyQg82K25yGnJ441q3duWX6XEgMKb58BJ63FteYIeh0GCX-GRK7Qj7cC4DaL3bJwjXLDfuqw6C5WJ43Xmbg4KnNKiub1SaLaE5LaKsZH_svMvH_ki8DKJxmJejb63xhdWzQ",
func(t *jwt.Token, a algo.Algorithm) error {
return t.Verify(a)
},
},
{
"eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUZ3pObVF0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzEuZXhhbXBsZS5jb20iLCJleHAiOiIyMDI1LTAxLTExVDAxOjE5OjM3LjAxMjE1NjQxOCswMzowMCIsImlhdCI6IjIwMjUtMDEtMTFUMDA6MTk6MzcuMDEyMTU2NDE4KzAzOjAwIiwiaXNzIjoiaHR0cHM6Ly9naXQuZGFlYnQuZGV2Iiwic3ViIjoiZXhhbXBsZToxIn0.YL8rblFD-kelG_UYoEkaa7B7b0Em520croAc4PicKc-HmyE4n2pKB1JBL105No7XTFkCN4I9ddtZ5O0ffPAiGtWR3Mhorl8o4FBBnWFSQJdYWuu5UZxYl2G2_AK6gE9HdIeujZiim9aHfhOmdhWJ-pJUiGAQjOvNHOYTPgvBug2sPvnpbkbbQc0ZsvHh2AtQ6BoTDJpLiSY7q7sr7AiGaJUL_4xcUSh5cd7Zief62FFjOBDWe1HWfGDr0JCrumrEWkT11lCXrftYaucRJ9gnD-qTh9Mx9iLCpk3YOuwMnw1q1ZgZ5SGJ2iwGoQnw40oYwd9aLj_OpWamT8jVcpAeQ2g",
func(t *jwt.Token, a algo.Algorithm) error {
err := t.Verify(a)
if errors.Is(err, algo.ErrInvalidSignature) {
return nil
}
return err
},
},
{
"eyJhbGciOiJSUzI1NiIsImtpZCI6Ik1ERTVORFV4WmpFdFpqYzRPUzAzTW1FMkxUSTVNVEV0TTJKak5qRTBObUZrTnpaaCIsInR5cCI6IkpXVCJ9.eyJhdWQiOiJodHRwczovLzIuZXhhbXBsZS5jb20iLCJpYXQiOiIyMDI1LTAxLTExVDAwOjE5OjM3LjAxMjE1NjQxOCswMzowMCIsImlzcyI6Imh0dHBzOi8vZ2l0LmRhZWJ0LmRldiIsIm5iZiI6IjIwMjUtMDEtMTFUMDA6MTk6MzcuMDEyMTU2NDE4KzAzOjAwIiwic3ViIjoiZXhhbXBsZToyIn0.Gg16e-pYVG0yokUTLf6M5-s2eirMvnIm_k-ebFNonlW9I5aSqc0PEI9e-Lp2eiagDqw0OSg8gt-V3xyhbU178W63VpxlqaOvxfjLWiW2Ql1qLlZUWDzEEfHVAdbyPUGYM06uVMv51ZwbnzdXTt5hlPK3dzV04yM5ISSsVBf9XhsPVBeUO8c0SKlnE4AdGE6nfjQKKKZ6eSkYJz5nRVWe5dbZC-vDAVYv5zj7L5iI7195tHWSlEQoAy9pYhPE_jVmtggsWBrY-U9hg3elrq3h2OhOKSUXT7g8zSLxOvSleg24qXaKP58trfCIks2xC-60G-6vx2e_XUI8hSp1suhOXw",
func(t *jwt.Token, a algo.Algorithm) error {
if a == nil {
return nil
}
return errors.New("a is not nil")
},
},
}
lst, err := jwk.ParseList([]byte(rows), nil)
if err != nil {
t.Fatal(err.Error())
}
for i, v := range data {
t.Run(fmt.Sprint(i), func(t *testing.T) {
tkn, err := jwt.Parse(v.t)
if err != nil {
t.Fatal(err.Error())
}
kid, err := tkn.Header.GetKeyId()
if err != nil {
t.Fatal(err.Error())
}
var alg algo.Algorithm
key := lst.SelectByKid(kid)
if key != nil {
alg, err = key.Algo()
if err != nil {
t.Fatal(err.Error())
}
}
if err := v.f(tkn, alg); err != nil {
t.Fatal(err.Error())
}
})
}
}
var compareTests = []struct {
a, b []byte
i int
}{
{[]byte(""), []byte(""), 0},
{[]byte("a"), []byte(""), 1},
{[]byte(""), []byte("a"), -1},
{[]byte("abc"), []byte("abc"), 0},
{[]byte("abd"), []byte("abc"), 1},
{[]byte("abc"), []byte("abd"), -1},
{[]byte("ab"), []byte("abc"), -1},
{[]byte("abc"), []byte("ab"), 1},
{[]byte("x"), []byte("ab"), 1},
{[]byte("ab"), []byte("x"), -1},
{[]byte("x"), []byte("a"), 1},
{[]byte("b"), []byte("x"), -1},
// test runtime·memeq's chunked implementation
{[]byte("abcdefgh"), []byte("abcdefgh"), 0},
{[]byte("abcdefghi"), []byte("abcdefghi"), 0},
{[]byte("abcdefghi"), []byte("abcdefghj"), -1},
{[]byte("abcdefghj"), []byte("abcdefghi"), 1},
{[]byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), []byte("abcdefghi"), 1},
{[]byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), []byte(" dahsdhajhdlj hadiuahdiuahsdupiu hh pauhd puhapduhapda"), 0},
// nil tests
{nil, nil, 0},
{[]byte(""), nil, 0},
{nil, []byte(""), 0},
{[]byte("a"), nil, 1},
{nil, []byte("a"), -1},
}
func BenchmarkEq(b *testing.B) {
var data = []func(a, b []byte) bool{
func(a, b []byte) bool {
return string(a) == string(b)
},
func(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i := 0; i < len(a); i++ {
if a[i] != b[i] {
return false
}
}
return true
},
func(a, b []byte) bool {
if len(a) != len(b) {
return false
}
for i, v := range a {
if b[i] != v {
return false
}
}
return true
},
}
for i, d := range data {
b.Run(fmt.Sprint(i), func(b *testing.B) {
for _, v := range compareTests {
if d(v.a, v.b) && v.i != 0 {
b.Fatal(string(v.a), " = ", string(v.b))
}
}
})
}
}

117
jwk/rsa.go Normal file
View File

@ -0,0 +1,117 @@
package jwk
import (
"crypto/rsa"
"encoding/base64"
"errors"
"math/big"
"git.daebt.dev/auth/algo"
"git.daebt.dev/auth/algo/rs"
)
type TokenRSA struct {
*Header
E string `json:"e"`
N string `json:"n"`
}
func (t *TokenRSA) encodeE(v int) {
var (
buf = []byte{}
skp = true
)
for i := 56; i > -1; i -= 8 {
b := byte(v >> i)
if skp && b == 0 {
continue
}
skp = false
buf = append(buf, b)
}
t.E = base64.RawURLEncoding.EncodeToString(buf)
}
func (t *TokenRSA) decodeE() (int, error) {
if t.E == "" {
return 0, errors.New("e is empty")
}
buf, err := base64.RawURLEncoding.DecodeString(t.E)
if err != nil {
return 0, err
}
var (
num int
l = len(buf) - 1
)
for i, b := range buf {
i = (l - i) * 8
num = num | int(b)<<i
i += 8
}
return num, nil
}
func (t *TokenRSA) Algo() (algo.Algorithm, error) {
n, err := base64.RawURLEncoding.DecodeString(t.N)
if err != nil {
return nil, err
}
e, err := t.decodeE()
if err != nil {
return nil, err
}
o := rs.WithPub(&rsa.PublicKey{
N: new(big.Int).SetBytes(n),
E: e,
})
switch t.Header.Algorithm {
case rs.AlgorithmRS256:
return rs.NewRS256(o)
case rs.AlgorithmRS384:
return rs.NewRS384(o)
case rs.AlgorithmRS512:
return rs.NewRS512(o)
}
return nil, errors.New("undefined algorithm type")
}
func (t *TokenRSA) KeyId() []byte {
if t.Header.KeyId != "" {
if val, err := base64.RawURLEncoding.DecodeString(t.Header.KeyId); err == nil {
return val
}
}
return nil
}
func NewRSA(kid []byte, v *rs.Algo) Token {
pub := v.PublicKey()
t := &TokenRSA{
Header: &Header{
KeyId: "",
KeyType: v.Key(),
Use: UseSignature,
Algorithm: v.Algo(),
},
E: "",
N: base64.RawURLEncoding.EncodeToString(pub.N.Bytes()),
}
if len(kid) > 0 {
t.Header.KeyId = base64.RawURLEncoding.EncodeToString(kid)
}
t.encodeE(pub.E)
return t
}