create: pkg jwt

This commit is contained in:
shchva 2025-01-11 00:42:32 +03:00
parent 1e5a63ba68
commit d3d2549829
8 changed files with 570 additions and 0 deletions

2
go.mod
View File

@ -1,3 +1,5 @@
module git.daebt.dev/auth
go 1.23.3
require github.com/gofrs/uuid v4.4.0+incompatible

2
go.sum Normal file
View File

@ -0,0 +1,2 @@
github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA=
github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM=

30
jwt/header.go Normal file
View File

@ -0,0 +1,30 @@
package jwt
import (
"encoding/base64"
"git.daebt.dev/auth/algo"
)
type Header struct {
*Map
}
func (h *Header) GetKeyId() ([]byte, error) {
var val string
if err := h.Unmarshal("kid", &val); err != nil {
return nil, err
}
return base64.RawURLEncoding.DecodeString(val)
}
func (h *Header) GetType() (string, error) {
var val string
return val, h.Unmarshal("typ", &val)
}
func (h *Header) GetAlgorithm() (algo.AlgorithmType, error) {
var val algo.AlgorithmType
return val, h.Unmarshal("alg", &val)
}

165
jwt/jwt.go Normal file
View File

@ -0,0 +1,165 @@
package jwt
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"git.daebt.dev/auth/algo"
)
var (
ErrKeyNotExist = errors.New("key does not exist")
ErrKeyEmpty = errors.New("key is empty")
ErrKeyInvalidType = errors.New("key is of invalid type")
ErrPayloadEmpty = errors.New("payload is empty")
ErrAlgorithmNil = errors.New("algorithm is nil")
ErrTokenMalformed = errors.New("token is malformed")
//
ErrNotJWTType = errors.New("token of not JWT type")
ErrAlgorithmMismatch = errors.New("token is signed by another algorithm")
)
type Token struct {
Header *Header
Payload *Payload
body []byte
sign []byte
}
func (t *Token) Sign(a algo.Algorithm) (string, error) {
if a == nil {
return "", ErrAlgorithmNil
}
t.Header.AppendArg("alg", a.Algo())
h, err := t.encodeSegment(t.Header.Map)
if err != nil {
return "", err
}
p, err := t.encodeSegment(t.Payload.Map)
if err != nil {
return "", err
}
t.body = []byte(h + "." + p)
t.sign, err = a.Sign(t.body)
if err != nil {
return "", err
}
return fmt.Sprintf("%s.%s", string(t.body),
base64.RawURLEncoding.EncodeToString(t.sign),
), nil
}
func (t *Token) encodeSegment(val any) (string, error) {
buf, err := json.Marshal(val)
if err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(buf), nil
}
func (t *Token) decodeSegment(str string, val any) error {
buf, err := base64.RawURLEncoding.DecodeString(str)
if err != nil {
return err
}
return json.Unmarshal(buf, val)
}
func (t *Token) Verify(a algo.Algorithm) error {
if a == nil {
return ErrAlgorithmNil
}
if val, err := t.Header.GetAlgorithm(); err != nil {
return err
} else if val != a.Algo() {
return ErrAlgorithmMismatch
}
return a.Verify(t.body, t.sign)
}
func New(o ...Option) *Token {
t := &Token{
Header: &Header{
Map: &Map{
v: map[string]json.RawMessage{},
},
},
Payload: &Payload{
Map: &Map{
v: map[string]json.RawMessage{},
},
},
}
for _, f := range o {
f(t)
}
t.Header.AppendArg("typ", "JWT")
return t
}
func Parse(token string) (*Token, error) {
arr := strings.Split(token, ".")
if len(arr) != 3 {
return nil, ErrTokenMalformed
}
// decode signature
s, err := base64.RawURLEncoding.DecodeString(arr[2])
if err != nil {
return nil, err
}
hp := new(bytes.Buffer)
hp.WriteString(arr[0])
hp.WriteByte(0x2E) // .
hp.WriteString(arr[1])
defer hp.Reset()
t := &Token{
Header: &Header{
Map: &Map{},
},
Payload: &Payload{
Map: &Map{},
},
body: hp.Bytes(),
sign: s,
}
// decode header
if err := t.decodeSegment(arr[0], t.Header.Map); err != nil {
return t, err
}
if val, err := t.Header.GetType(); err != nil {
return t, err
} else if val != "JWT" {
return t, ErrNotJWTType
}
// decode payload
if err := t.decodeSegment(arr[1], t.Payload.Map); err != nil {
return t, err
}
return t, nil
}

135
jwt/jwt_test.go Normal file
View File

@ -0,0 +1,135 @@
package jwt_test
import (
"fmt"
"testing"
"time"
"git.daebt.dev/auth/algo/rs"
"git.daebt.dev/auth/jwt"
)
var key = `-----BEGIN RSA PRIVATE KEY-----
MIIEpAIBAAKCAQEAzIWvl1OtwExnQ3HvoYk6bFRIlcCjzdv1yHazJfr6jxk6w+tC
sIWEdtKsAPm3gmmFG+mTuHq+H53sahm6DD9YC5ZQjnvSYkBKv70Zw331/tg9VLbf
Jc+gN7kbD3xMQsucYD0973r7l9pEPH4Qw/I+BEKHMxlTmynStgKxnfyO6iPkL5jT
zpUXlD4V9xqUoMY/uX3EpGwbJJKFJuphYX3jzJQ++tovQGGep7RgNeMEoWjyAkJ2
yb7tiSWsw7qk6GO2z6NDmnc1UsdSBZ6Vg7BPUp8EINAdX1wbmB0+QH/vp0huM6lS
Y6NMBugwptQUe5UCaly9fN4kb26U0qglEoGB6wIDAQABAoIBADU244UgRKkwN/4Y
ex0ws37UPz6XrQc3IDBUkjBjqSXqjpvDbsq3Mswn7JEkaFcKVZP5pnHtneJkGMtS
flIJeUMqjTNFjGv8Bnb1IOr4rzTr1qlgG5ee+jUFeMECumT2zW1NAfx5p1TPecmz
k3EoanJ5TOxCvro0m5q4ALb2q8jHrtfvtqEBrHepeEp3Lyh7m4ZUib+0yWXs0EPC
HhF9kLpCy+tVXUPDyLCt4cldTUda/3xeswzmxVRrHkt/idsNuTAi7o1Cx/OfZYMI
AzQo8OTh1Bg2DlXKtOX40frIxy3/K77F3ozwV4a3FravUO+wvcQdIyi6KAmNBFS3
9IVA7IECgYEA4x0TAn7vOvazILmg4Es0/gsWlh7RGmNTqVYtj5TfwOviUGeaculR
BSsyX8pgaROJKjGDcNzSQQEhHfJXrhNXeMJ0zPUIsJCBmih8oaCpAScWpV+qpdky
1Eb2akEg7XbpqBJJ1jnoEvIhd4feCAN8Gv8vcmdER7HaGdyef4XxlFsCgYEA5okG
tbyTtD2cfmYjYsoGqEfGH0Pe9vxc+MBthiPg0f2lpg+YuSPx92ZuJciLNyNWo9qf
NFnzbSEFxzomK/Bgq9ujGnbPyLOCadIADM4/njEEPe+IsagDxBgTrCEUJ56W9MLj
N+b4d/gnBkK4roDW8gjy7x4MbePByoDfaWtU/bECgYEApn8RCZpe7V4gMdSEIQph
fgBI/aL37p10nsbDvegJJRiIoCNjsexj7iMd2eW2SjH9M4Z68smgBfG7AoZASyh4
ztnX4M2eIjq+GHKn86GhZGvwiSoaI12YitC/I2Q9rHipkQJfSQLIpOMHL+bWGg/b
8rqzYO5duyWiW6VGOPzL/tMCgYB3JVSZcrfnzHvn+8PIF9+u80FbAUnn3m/yhAlW
7Y4RGYWWOLNW5FP26DJ/RpFk0tfBYYksllywBwQkflIiHV7pE1/NmqAy+0uog0dR
VvscN/sYQ4cjQlGH9GWebY4sF9Ou9lZWmwHJhzAsFSm7zozIlIVxvdbwqGiMz2Qn
6LgJUQKBgQC9H2JGm54wg0YPuDig5LjymUxYJrEiJT0IXz4vy+UEMxw+1EmeD5sm
kSqHkwNDp7D+3nik5HzoFVifJAvqFWU73fpvqQlvZSNfVrtq8UvJBIuH7eHkrJrC
L8dEn16HWjLX50GlT+9eYyHWtYI4sMdnzz1/JS6PwQRxKlFQN9HJYg==
-----END RSA PRIVATE KEY-----`
func TestCreate(t *testing.T) {
alg, err := rs.NewRS256(
rs.WithPEM(nil, []byte(key)),
)
if err != nil {
t.Fatal(err.Error())
}
tm := time.Now()
var data = [][]jwt.Option{
{
jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-836d-3bc6146ad76a")),
jwt.WithIssuer("https://git.daebt.dev"),
jwt.WithAudience("https://0.example.com"),
jwt.WithSubject("example:0"),
},
{
jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-836d-3bc6146ad76a")),
jwt.WithIssuer("https://git.daebt.dev"),
jwt.WithAudience("https://1.example.com"),
jwt.WithSubject("example:1"),
jwt.WithIssuedAt(tm),
jwt.WithExpirationTime(tm.Add(time.Hour)),
},
{
jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-2911-3bc6146ad76a")),
jwt.WithIssuer("https://git.daebt.dev"),
jwt.WithAudience("https://2.example.com"),
jwt.WithSubject("example:2"),
jwt.WithIssuedAt(tm),
jwt.WithNotBefore(tm),
},
{
jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-2911-3bc6146ad76a")),
jwt.WithIssuer("https://git.daebt.dev"),
jwt.WithAudience("https://3.example.com"),
jwt.WithSubject("example:3"),
jwt.WithIssuedAt(tm),
jwt.WithNotBefore(tm.Add(time.Minute)),
},
{
jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-2911-3bc6146ad76a")),
jwt.WithIssuer("https://git.daebt.dev"),
jwt.WithAudience("https://4.example.com"),
jwt.WithSubject("example:4"),
jwt.WithIssuedAt(tm),
jwt.WithNotBefore(tm.Add(time.Minute)),
jwt.WithExpirationTime(tm.Add(time.Hour)),
},
}
for _, v := range data {
// t.Run(fmt.Sprint(i), func(t *testing.T) {
val, err := jwt.New(v...).Sign(alg)
if err != nil {
t.Fatal(err.Error())
}
// t.Log(val)
fmt.Printf(`"%s",`, val)
// })
}
}
// func TestJwt(t *testing.T) {
// alg, err := rs.NewRS256(
// rs.WithGenerateKey(2048),
// )
// if err != nil {
// t.Fatal(err.Error())
// }
// tkn := New(
// WithSubject("main-jwt-token"),
// )
// str, err := tkn.Sign(alg)
// if err != nil {
// t.Fatal(err.Error())
// }
// tkn, err = Parse(str)
// if err != nil {
// t.Fatal(err.Error())
// }
// if err := tkn.Verify(alg); err != nil {
// t.Fatal(err.Error())
// }
// tkn.Payload.Range(func(key string, val any) bool {
// t.Log(key, val)
// return true
// })
// }

70
jwt/map.go Normal file
View File

@ -0,0 +1,70 @@
package jwt
import "encoding/json"
type Map struct {
v map[string]json.RawMessage
}
func (m *Map) UnmarshalJSON(buf []byte) error {
if m.v == nil {
m.v = map[string]json.RawMessage{}
}
return json.Unmarshal(buf, &m.v)
}
func (m *Map) MarshalJSON() ([]byte, error) {
if m.v == nil || len(m.v) < 1 {
return nil, ErrPayloadEmpty
}
return json.Marshal(m.v)
}
// AppendArg добавляет аргумент
func (m *Map) AppendArg(key string, val any) error {
if key == "" {
return ErrKeyEmpty
}
var err error
m.v[key], err = json.Marshal(val)
return err
}
// AppendArgs добавляет аргументы
func (m *Map) AppendArgs(kv map[string]any) error {
var err error
for k, v := range kv {
if k == "" {
return ErrKeyEmpty
}
m.v[k], err = json.Marshal(v)
if err != nil {
return err
}
}
return nil
}
func (m *Map) Unmarshal(key string, val any) error {
buf, is := m.v[key]
if !is {
return ErrKeyNotExist
}
return json.Unmarshal(buf, val)
}
func (m *Map) Range(f func(key string, val any) bool) {
for k := range m.v {
var val any
m.Unmarshal(k, &val)
if !f(k, val) {
break
}
}
}

90
jwt/option.go Normal file
View File

@ -0,0 +1,90 @@
package jwt
import (
"encoding/base64"
"encoding/hex"
"time"
"github.com/gofrs/uuid"
)
type Option func(*Token)
// WithIssuer устанавливает идентификатор принципала, выдавшего JWT (string, URL)
func WithIssuer(iss string) Option {
return func(t *Token) {
t.Payload.AppendArg("iss", iss)
}
}
// WithSubject устанавливает идентификатор принципала, который является предметом JWT (string, URL)
func WithSubject(sub string) Option {
return func(t *Token) {
t.Payload.AppendArg("sub", sub)
}
}
// WithAudience устанавливает идентификатор получателей, для которых предназначен JWT
func WithAudience(aud ...string) Option {
return func(t *Token) {
if len(aud) == 1 {
t.Payload.AppendArg("aud", aud[0])
return
}
t.Payload.AppendArg("aud", aud)
}
}
// WithExpirationTime устанавливает время истечения срока действия, по истечении которого JWT НЕ ДОЛЖЕН быть принят к обработке
func WithExpirationTime(exp time.Time) Option {
return func(t *Token) {
t.Payload.AppendArg("exp", exp)
}
}
// WithNotBefore устанавливает время, до которого JWT НЕ ДОЛЖЕН быть принят к обработке
func WithNotBefore(nbf time.Time) Option {
return func(t *Token) {
t.Payload.AppendArg("nbf", nbf)
}
}
// WithIssuedAt устанавливает время, когда был создан JWT
func WithIssuedAt(iat time.Time) Option {
return func(t *Token) {
t.Payload.AppendArg("iat", iat)
}
}
// WithIssuedAtNow устанавливает текущее время создания JWT
func WithIssuedAtNow() Option {
return func(t *Token) {
t.Payload.AppendArg("iat", time.Now())
}
}
// WithIssuedAt устанавливает уникальный идентификатор для JWT
func WithJwtId(jti string) Option {
return func(t *Token) {
t.Payload.AppendArg("jti", jti)
}
}
// WithGenJwtId генерирует уникальный идентификатор для JWT (uuid v7)
func WithGenJwtId() Option {
return func(t *Token) {
jti := hex.EncodeToString([]byte(time.Now().String()))
if val, err := uuid.NewV7(); err == nil {
jti = val.String()
}
t.Payload.AppendArg("jti", jti)
}
}
// WithHeaderKeyId устанавливает в заголовке параметр kid
func WithHeaderKeyId(kid []byte) Option {
return func(t *Token) {
t.Header.AppendArg("kid", base64.RawURLEncoding.EncodeToString(kid))
}
}

76
jwt/payload.go Normal file
View File

@ -0,0 +1,76 @@
package jwt
import (
"encoding/json"
"math"
"time"
)
type Payload struct {
*Map
}
// GetExpirationTime возвращает срок действия JWT
func (p *Payload) GetExpirationTime() (time.Time, error) {
return p.GetTime("exp")
}
// GetIssuedAt возвращает время создания JWT
func (p *Payload) GetIssuedAt() (time.Time, error) {
return p.GetTime("iat")
}
// GetNotBefore возвращает время начала действия JWT
func (p *Payload) GetNotBefore() (time.Time, error) {
return p.GetTime("nbf")
}
// GetIssuer возвращает идентификатор принципала, выдавшего JWT
func (p *Payload) GetIssuer() (string, error) {
var val string
return val, p.Unmarshal("iss", &val)
}
// GetSubject возвращает идентификатор принципала, который является предметом JWT
func (p *Payload) GetSubject() (string, error) {
var val string
return val, p.Unmarshal("sub", &val)
}
// GetAudience возвращает идентификатор получателей
func (p *Payload) GetAudience() ([]string, error) {
var val any
if err := p.Unmarshal("aud", &val); err != nil {
return nil, err
}
switch v := val.(type) {
case string:
return []string{v}, nil
case []string:
return v, nil
}
return nil, ErrKeyInvalidType
}
func (p *Payload) GetTime(key string) (time.Time, error) {
var val json.Number
if err := p.Unmarshal(key, &val); err != nil {
return time.Time{}, err
}
num, err := val.Float64()
if err != nil {
return time.Time{}, err
}
round, frac := math.Modf(num)
return time.Unix(int64(round), int64(frac*1e9)), nil
}
// GetAny получает значение из полезной нагрузки
func (p *Payload) GetAny(key string) (any, error) {
var val any
return val, p.Unmarshal(key, &val)
}