diff --git a/go.mod b/go.mod index 4a5be7f..b727ae8 100644 --- a/go.mod +++ b/go.mod @@ -1,3 +1,5 @@ module git.daebt.dev/auth go 1.23.3 + +require github.com/gofrs/uuid v4.4.0+incompatible diff --git a/go.sum b/go.sum new file mode 100644 index 0000000..c0ad687 --- /dev/null +++ b/go.sum @@ -0,0 +1,2 @@ +github.com/gofrs/uuid v4.4.0+incompatible h1:3qXRTX8/NbyulANqlc0lchS1gqAVxRgsuW1YrTJupqA= +github.com/gofrs/uuid v4.4.0+incompatible/go.mod h1:b2aQJv3Z4Fp6yNu3cdSllBxTCLRxnplIgP/c0N/04lM= diff --git a/jwt/header.go b/jwt/header.go new file mode 100644 index 0000000..fe4e100 --- /dev/null +++ b/jwt/header.go @@ -0,0 +1,30 @@ +package jwt + +import ( + "encoding/base64" + + "git.daebt.dev/auth/algo" +) + +type Header struct { + *Map +} + +func (h *Header) GetKeyId() ([]byte, error) { + var val string + if err := h.Unmarshal("kid", &val); err != nil { + return nil, err + } + + return base64.RawURLEncoding.DecodeString(val) +} + +func (h *Header) GetType() (string, error) { + var val string + return val, h.Unmarshal("typ", &val) +} + +func (h *Header) GetAlgorithm() (algo.AlgorithmType, error) { + var val algo.AlgorithmType + return val, h.Unmarshal("alg", &val) +} diff --git a/jwt/jwt.go b/jwt/jwt.go new file mode 100644 index 0000000..22ab8f0 --- /dev/null +++ b/jwt/jwt.go @@ -0,0 +1,165 @@ +package jwt + +import ( + "bytes" + "encoding/base64" + "encoding/json" + "errors" + "fmt" + "strings" + + "git.daebt.dev/auth/algo" +) + +var ( + ErrKeyNotExist = errors.New("key does not exist") + ErrKeyEmpty = errors.New("key is empty") + ErrKeyInvalidType = errors.New("key is of invalid type") + + ErrPayloadEmpty = errors.New("payload is empty") + + ErrAlgorithmNil = errors.New("algorithm is nil") + + ErrTokenMalformed = errors.New("token is malformed") + // + ErrNotJWTType = errors.New("token of not JWT type") + ErrAlgorithmMismatch = errors.New("token is signed by another algorithm") +) + +type Token struct { + Header *Header + Payload *Payload + body []byte + sign []byte +} + +func (t *Token) Sign(a algo.Algorithm) (string, error) { + if a == nil { + return "", ErrAlgorithmNil + } + + t.Header.AppendArg("alg", a.Algo()) + + h, err := t.encodeSegment(t.Header.Map) + if err != nil { + return "", err + } + + p, err := t.encodeSegment(t.Payload.Map) + if err != nil { + return "", err + } + + t.body = []byte(h + "." + p) + + t.sign, err = a.Sign(t.body) + if err != nil { + return "", err + } + + return fmt.Sprintf("%s.%s", string(t.body), + base64.RawURLEncoding.EncodeToString(t.sign), + ), nil +} + +func (t *Token) encodeSegment(val any) (string, error) { + buf, err := json.Marshal(val) + if err != nil { + return "", err + } + + return base64.RawURLEncoding.EncodeToString(buf), nil +} + +func (t *Token) decodeSegment(str string, val any) error { + buf, err := base64.RawURLEncoding.DecodeString(str) + if err != nil { + return err + } + + return json.Unmarshal(buf, val) +} + +func (t *Token) Verify(a algo.Algorithm) error { + if a == nil { + return ErrAlgorithmNil + } + + if val, err := t.Header.GetAlgorithm(); err != nil { + return err + } else if val != a.Algo() { + return ErrAlgorithmMismatch + } + + return a.Verify(t.body, t.sign) +} + +func New(o ...Option) *Token { + t := &Token{ + Header: &Header{ + Map: &Map{ + v: map[string]json.RawMessage{}, + }, + }, + Payload: &Payload{ + Map: &Map{ + v: map[string]json.RawMessage{}, + }, + }, + } + for _, f := range o { + f(t) + } + + t.Header.AppendArg("typ", "JWT") + + return t +} + +func Parse(token string) (*Token, error) { + arr := strings.Split(token, ".") + if len(arr) != 3 { + return nil, ErrTokenMalformed + } + + // decode signature + s, err := base64.RawURLEncoding.DecodeString(arr[2]) + if err != nil { + return nil, err + } + + hp := new(bytes.Buffer) + hp.WriteString(arr[0]) + hp.WriteByte(0x2E) // . + hp.WriteString(arr[1]) + defer hp.Reset() + + t := &Token{ + Header: &Header{ + Map: &Map{}, + }, + Payload: &Payload{ + Map: &Map{}, + }, + body: hp.Bytes(), + sign: s, + } + + // decode header + if err := t.decodeSegment(arr[0], t.Header.Map); err != nil { + return t, err + } + + if val, err := t.Header.GetType(); err != nil { + return t, err + } else if val != "JWT" { + return t, ErrNotJWTType + } + + // decode payload + if err := t.decodeSegment(arr[1], t.Payload.Map); err != nil { + return t, err + } + + return t, nil +} diff --git a/jwt/jwt_test.go b/jwt/jwt_test.go new file mode 100644 index 0000000..35ccfac --- /dev/null +++ b/jwt/jwt_test.go @@ -0,0 +1,135 @@ +package jwt_test + +import ( + "fmt" + "testing" + "time" + + "git.daebt.dev/auth/algo/rs" + "git.daebt.dev/auth/jwt" +) + +var key = `-----BEGIN RSA PRIVATE KEY----- +MIIEpAIBAAKCAQEAzIWvl1OtwExnQ3HvoYk6bFRIlcCjzdv1yHazJfr6jxk6w+tC +sIWEdtKsAPm3gmmFG+mTuHq+H53sahm6DD9YC5ZQjnvSYkBKv70Zw331/tg9VLbf +Jc+gN7kbD3xMQsucYD0973r7l9pEPH4Qw/I+BEKHMxlTmynStgKxnfyO6iPkL5jT +zpUXlD4V9xqUoMY/uX3EpGwbJJKFJuphYX3jzJQ++tovQGGep7RgNeMEoWjyAkJ2 +yb7tiSWsw7qk6GO2z6NDmnc1UsdSBZ6Vg7BPUp8EINAdX1wbmB0+QH/vp0huM6lS +Y6NMBugwptQUe5UCaly9fN4kb26U0qglEoGB6wIDAQABAoIBADU244UgRKkwN/4Y +ex0ws37UPz6XrQc3IDBUkjBjqSXqjpvDbsq3Mswn7JEkaFcKVZP5pnHtneJkGMtS +flIJeUMqjTNFjGv8Bnb1IOr4rzTr1qlgG5ee+jUFeMECumT2zW1NAfx5p1TPecmz +k3EoanJ5TOxCvro0m5q4ALb2q8jHrtfvtqEBrHepeEp3Lyh7m4ZUib+0yWXs0EPC +HhF9kLpCy+tVXUPDyLCt4cldTUda/3xeswzmxVRrHkt/idsNuTAi7o1Cx/OfZYMI +AzQo8OTh1Bg2DlXKtOX40frIxy3/K77F3ozwV4a3FravUO+wvcQdIyi6KAmNBFS3 +9IVA7IECgYEA4x0TAn7vOvazILmg4Es0/gsWlh7RGmNTqVYtj5TfwOviUGeaculR +BSsyX8pgaROJKjGDcNzSQQEhHfJXrhNXeMJ0zPUIsJCBmih8oaCpAScWpV+qpdky +1Eb2akEg7XbpqBJJ1jnoEvIhd4feCAN8Gv8vcmdER7HaGdyef4XxlFsCgYEA5okG +tbyTtD2cfmYjYsoGqEfGH0Pe9vxc+MBthiPg0f2lpg+YuSPx92ZuJciLNyNWo9qf +NFnzbSEFxzomK/Bgq9ujGnbPyLOCadIADM4/njEEPe+IsagDxBgTrCEUJ56W9MLj +N+b4d/gnBkK4roDW8gjy7x4MbePByoDfaWtU/bECgYEApn8RCZpe7V4gMdSEIQph +fgBI/aL37p10nsbDvegJJRiIoCNjsexj7iMd2eW2SjH9M4Z68smgBfG7AoZASyh4 +ztnX4M2eIjq+GHKn86GhZGvwiSoaI12YitC/I2Q9rHipkQJfSQLIpOMHL+bWGg/b +8rqzYO5duyWiW6VGOPzL/tMCgYB3JVSZcrfnzHvn+8PIF9+u80FbAUnn3m/yhAlW +7Y4RGYWWOLNW5FP26DJ/RpFk0tfBYYksllywBwQkflIiHV7pE1/NmqAy+0uog0dR +VvscN/sYQ4cjQlGH9GWebY4sF9Ou9lZWmwHJhzAsFSm7zozIlIVxvdbwqGiMz2Qn +6LgJUQKBgQC9H2JGm54wg0YPuDig5LjymUxYJrEiJT0IXz4vy+UEMxw+1EmeD5sm +kSqHkwNDp7D+3nik5HzoFVifJAvqFWU73fpvqQlvZSNfVrtq8UvJBIuH7eHkrJrC +L8dEn16HWjLX50GlT+9eYyHWtYI4sMdnzz1/JS6PwQRxKlFQN9HJYg== +-----END RSA PRIVATE KEY-----` + +func TestCreate(t *testing.T) { + alg, err := rs.NewRS256( + rs.WithPEM(nil, []byte(key)), + ) + if err != nil { + t.Fatal(err.Error()) + } + + tm := time.Now() + + var data = [][]jwt.Option{ + { + jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-836d-3bc6146ad76a")), + jwt.WithIssuer("https://git.daebt.dev"), + jwt.WithAudience("https://0.example.com"), + jwt.WithSubject("example:0"), + }, + { + jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-836d-3bc6146ad76a")), + jwt.WithIssuer("https://git.daebt.dev"), + jwt.WithAudience("https://1.example.com"), + jwt.WithSubject("example:1"), + jwt.WithIssuedAt(tm), + jwt.WithExpirationTime(tm.Add(time.Hour)), + }, + { + jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-2911-3bc6146ad76a")), + jwt.WithIssuer("https://git.daebt.dev"), + jwt.WithAudience("https://2.example.com"), + jwt.WithSubject("example:2"), + jwt.WithIssuedAt(tm), + jwt.WithNotBefore(tm), + }, + { + jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-2911-3bc6146ad76a")), + jwt.WithIssuer("https://git.daebt.dev"), + jwt.WithAudience("https://3.example.com"), + jwt.WithSubject("example:3"), + jwt.WithIssuedAt(tm), + jwt.WithNotBefore(tm.Add(time.Minute)), + }, + { + jwt.WithHeaderKeyId([]byte("019451f1-f789-72a6-2911-3bc6146ad76a")), + jwt.WithIssuer("https://git.daebt.dev"), + jwt.WithAudience("https://4.example.com"), + jwt.WithSubject("example:4"), + jwt.WithIssuedAt(tm), + jwt.WithNotBefore(tm.Add(time.Minute)), + jwt.WithExpirationTime(tm.Add(time.Hour)), + }, + } + + for _, v := range data { + // t.Run(fmt.Sprint(i), func(t *testing.T) { + val, err := jwt.New(v...).Sign(alg) + if err != nil { + t.Fatal(err.Error()) + } + // t.Log(val) + fmt.Printf(`"%s",`, val) + // }) + } + +} + +// func TestJwt(t *testing.T) { +// alg, err := rs.NewRS256( +// rs.WithGenerateKey(2048), +// ) +// if err != nil { +// t.Fatal(err.Error()) +// } + +// tkn := New( +// WithSubject("main-jwt-token"), +// ) + +// str, err := tkn.Sign(alg) +// if err != nil { +// t.Fatal(err.Error()) +// } + +// tkn, err = Parse(str) +// if err != nil { +// t.Fatal(err.Error()) +// } + +// if err := tkn.Verify(alg); err != nil { +// t.Fatal(err.Error()) +// } + +// tkn.Payload.Range(func(key string, val any) bool { +// t.Log(key, val) +// return true +// }) +// } diff --git a/jwt/map.go b/jwt/map.go new file mode 100644 index 0000000..fa7ccf0 --- /dev/null +++ b/jwt/map.go @@ -0,0 +1,70 @@ +package jwt + +import "encoding/json" + +type Map struct { + v map[string]json.RawMessage +} + +func (m *Map) UnmarshalJSON(buf []byte) error { + if m.v == nil { + m.v = map[string]json.RawMessage{} + } + return json.Unmarshal(buf, &m.v) +} + +func (m *Map) MarshalJSON() ([]byte, error) { + if m.v == nil || len(m.v) < 1 { + return nil, ErrPayloadEmpty + } + + return json.Marshal(m.v) +} + +// AppendArg добавляет аргумент +func (m *Map) AppendArg(key string, val any) error { + if key == "" { + return ErrKeyEmpty + } + + var err error + m.v[key], err = json.Marshal(val) + return err +} + +// AppendArgs добавляет аргументы +func (m *Map) AppendArgs(kv map[string]any) error { + var err error + + for k, v := range kv { + if k == "" { + return ErrKeyEmpty + } + + m.v[k], err = json.Marshal(v) + if err != nil { + return err + } + } + + return nil +} + +func (m *Map) Unmarshal(key string, val any) error { + buf, is := m.v[key] + if !is { + return ErrKeyNotExist + } + + return json.Unmarshal(buf, val) +} + +func (m *Map) Range(f func(key string, val any) bool) { + for k := range m.v { + var val any + m.Unmarshal(k, &val) + if !f(k, val) { + break + } + } +} diff --git a/jwt/option.go b/jwt/option.go new file mode 100644 index 0000000..9a61413 --- /dev/null +++ b/jwt/option.go @@ -0,0 +1,90 @@ +package jwt + +import ( + "encoding/base64" + "encoding/hex" + "time" + + "github.com/gofrs/uuid" +) + +type Option func(*Token) + +// WithIssuer устанавливает идентификатор принципала, выдавшего JWT (string, URL) +func WithIssuer(iss string) Option { + return func(t *Token) { + t.Payload.AppendArg("iss", iss) + } +} + +// WithSubject устанавливает идентификатор принципала, который является предметом JWT (string, URL) +func WithSubject(sub string) Option { + return func(t *Token) { + t.Payload.AppendArg("sub", sub) + } +} + +// WithAudience устанавливает идентификатор получателей, для которых предназначен JWT +func WithAudience(aud ...string) Option { + return func(t *Token) { + if len(aud) == 1 { + t.Payload.AppendArg("aud", aud[0]) + return + } + + t.Payload.AppendArg("aud", aud) + } +} + +// WithExpirationTime устанавливает время истечения срока действия, по истечении которого JWT НЕ ДОЛЖЕН быть принят к обработке +func WithExpirationTime(exp time.Time) Option { + return func(t *Token) { + t.Payload.AppendArg("exp", exp) + } +} + +// WithNotBefore устанавливает время, до которого JWT НЕ ДОЛЖЕН быть принят к обработке +func WithNotBefore(nbf time.Time) Option { + return func(t *Token) { + t.Payload.AppendArg("nbf", nbf) + } +} + +// WithIssuedAt устанавливает время, когда был создан JWT +func WithIssuedAt(iat time.Time) Option { + return func(t *Token) { + t.Payload.AppendArg("iat", iat) + } +} + +// WithIssuedAtNow устанавливает текущее время создания JWT +func WithIssuedAtNow() Option { + return func(t *Token) { + t.Payload.AppendArg("iat", time.Now()) + } +} + +// WithIssuedAt устанавливает уникальный идентификатор для JWT +func WithJwtId(jti string) Option { + return func(t *Token) { + t.Payload.AppendArg("jti", jti) + } +} + +// WithGenJwtId генерирует уникальный идентификатор для JWT (uuid v7) +func WithGenJwtId() Option { + return func(t *Token) { + jti := hex.EncodeToString([]byte(time.Now().String())) + if val, err := uuid.NewV7(); err == nil { + jti = val.String() + } + t.Payload.AppendArg("jti", jti) + } +} + +// WithHeaderKeyId устанавливает в заголовке параметр kid +func WithHeaderKeyId(kid []byte) Option { + return func(t *Token) { + t.Header.AppendArg("kid", base64.RawURLEncoding.EncodeToString(kid)) + } +} diff --git a/jwt/payload.go b/jwt/payload.go new file mode 100644 index 0000000..d6fa09e --- /dev/null +++ b/jwt/payload.go @@ -0,0 +1,76 @@ +package jwt + +import ( + "encoding/json" + "math" + "time" +) + +type Payload struct { + *Map +} + +// GetExpirationTime возвращает срок действия JWT +func (p *Payload) GetExpirationTime() (time.Time, error) { + return p.GetTime("exp") +} + +// GetIssuedAt возвращает время создания JWT +func (p *Payload) GetIssuedAt() (time.Time, error) { + return p.GetTime("iat") +} + +// GetNotBefore возвращает время начала действия JWT +func (p *Payload) GetNotBefore() (time.Time, error) { + return p.GetTime("nbf") +} + +// GetIssuer возвращает идентификатор принципала, выдавшего JWT +func (p *Payload) GetIssuer() (string, error) { + var val string + return val, p.Unmarshal("iss", &val) +} + +// GetSubject возвращает идентификатор принципала, который является предметом JWT +func (p *Payload) GetSubject() (string, error) { + var val string + return val, p.Unmarshal("sub", &val) +} + +// GetAudience возвращает идентификатор получателей +func (p *Payload) GetAudience() ([]string, error) { + var val any + if err := p.Unmarshal("aud", &val); err != nil { + return nil, err + } + + switch v := val.(type) { + case string: + return []string{v}, nil + case []string: + return v, nil + } + + return nil, ErrKeyInvalidType +} + +func (p *Payload) GetTime(key string) (time.Time, error) { + var val json.Number + if err := p.Unmarshal(key, &val); err != nil { + return time.Time{}, err + } + + num, err := val.Float64() + if err != nil { + return time.Time{}, err + } + + round, frac := math.Modf(num) + return time.Unix(int64(round), int64(frac*1e9)), nil +} + +// GetAny получает значение из полезной нагрузки +func (p *Payload) GetAny(key string) (any, error) { + var val any + return val, p.Unmarshal(key, &val) +}