194 lines
3.7 KiB
Go
194 lines
3.7 KiB
Go
package jwt
|
|
|
|
import (
|
|
"bytes"
|
|
"encoding/base64"
|
|
"encoding/json"
|
|
"errors"
|
|
"fmt"
|
|
"strings"
|
|
"time"
|
|
|
|
"git.daebt.dev/golang/auth/algo"
|
|
)
|
|
|
|
// Sentinel errors exposed by the jwt package. Match them with errors.Is.
var (
	// Key lookup errors.

	// ErrKeyNotExist reports that a requested key is absent. Verify treats
	// it as "claim not set" for the expiration and not-before checks.
	ErrKeyNotExist = errors.New("key does not exist")
	// ErrKeyEmpty reports that a key is empty.
	ErrKeyEmpty = errors.New("key is empty")
	// ErrKeyInvalidType reports that a key holds a value of an invalid type.
	ErrKeyInvalidType = errors.New("key is of invalid type")

	// ErrPayloadEmpty reports that the payload is empty.
	ErrPayloadEmpty = errors.New("payload is empty")

	// ErrAlgorithmNil is returned by Sign and Verify when the supplied
	// algorithm is nil.
	ErrAlgorithmNil = errors.New("algorithm is nil")

	// ErrTokenMalformed is returned by Parse when the token does not consist
	// of exactly three dot-separated segments.
	ErrTokenMalformed = errors.New("token is malformed")

	// Header validation errors.

	// ErrNotJWTType is returned by Parse when the header type is not "JWT".
	ErrNotJWTType = errors.New("token of not JWT type")
	// ErrAlgorithmMismatch is returned by Verify when the header algorithm
	// differs from the algorithm used for verification.
	ErrAlgorithmMismatch = errors.New("token is signed by another algorithm")

	// Claim validation errors.

	// Err_1 carries a Russian message meaning "the specified owner(s) do not
	// exist".
	// NOTE(review): not used in this part of the file; confirm which check
	// returns it.
	Err_1 = errors.New("заданного владельца(ов) не существует")
	// Err_2 carries a Russian message meaning "the specified aud does not
	// exist".
	// NOTE(review): not used in this part of the file; confirm which check
	// returns it.
	Err_2 = errors.New("заданного aud не существует")
	// ErrExpired is returned by Verify when the current time is after the
	// token's expiration time.
	ErrExpired = errors.New("token expired")
	// ErrNotActivated is returned by Verify when the current time is before
	// the token's not-before time.
	ErrNotActivated = errors.New("token yet not activated")
)
|
|
|
|
type Token struct {
|
|
Header *Header
|
|
Payload *Payload
|
|
body []byte
|
|
sign []byte
|
|
}
|
|
|
|
func (t *Token) Sign(a algo.Algorithm) (string, error) {
|
|
if a == nil {
|
|
return "", ErrAlgorithmNil
|
|
}
|
|
|
|
t.Header.AppendArg("alg", a.Algo())
|
|
|
|
h, err := t.encodeSegment(t.Header.Map)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
p, err := t.encodeSegment(t.Payload.Map)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
t.body = []byte(h + "." + p)
|
|
|
|
t.sign, err = a.Sign(t.body)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return fmt.Sprintf("%s.%s", string(t.body),
|
|
base64.RawURLEncoding.EncodeToString(t.sign),
|
|
), nil
|
|
}
|
|
|
|
func (t *Token) encodeSegment(val any) (string, error) {
|
|
buf, err := json.Marshal(val)
|
|
if err != nil {
|
|
return "", err
|
|
}
|
|
|
|
return base64.RawURLEncoding.EncodeToString(buf), nil
|
|
}
|
|
|
|
func (t *Token) decodeSegment(str string, val any) error {
|
|
buf, err := base64.RawURLEncoding.DecodeString(str)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
|
|
return json.Unmarshal(buf, val)
|
|
}
|
|
|
|
func (t *Token) Verify(a algo.Algorithm, o ...VerifyOption) error {
|
|
if a == nil {
|
|
return ErrAlgorithmNil
|
|
}
|
|
|
|
if val, err := t.Header.GetAlgorithm(); err != nil {
|
|
return err
|
|
} else if val != a.Algo() {
|
|
return ErrAlgorithmMismatch
|
|
}
|
|
|
|
if val, err := t.Payload.GetExpirationTime(); err != nil && !errors.Is(err, ErrKeyNotExist) {
|
|
return err
|
|
} else if !errors.Is(err, ErrKeyNotExist) {
|
|
if time.Now().After(val) {
|
|
return ErrExpired
|
|
}
|
|
}
|
|
|
|
if val, err := t.Payload.GetNotBefore(); err != nil && !errors.Is(err, ErrKeyNotExist) {
|
|
return err
|
|
} else if !errors.Is(err, ErrKeyNotExist) {
|
|
if time.Now().Before(val) {
|
|
return ErrNotActivated
|
|
}
|
|
}
|
|
|
|
for _, f := range o {
|
|
if err := f(t); err != nil {
|
|
return err
|
|
}
|
|
}
|
|
|
|
return a.Verify(t.body, t.sign)
|
|
}
|
|
|
|
func New(o ...Option) *Token {
|
|
t := &Token{
|
|
Header: &Header{
|
|
Map: &Map{
|
|
v: map[string]json.RawMessage{},
|
|
},
|
|
},
|
|
Payload: &Payload{
|
|
Map: &Map{
|
|
v: map[string]json.RawMessage{},
|
|
},
|
|
},
|
|
}
|
|
for _, f := range o {
|
|
f(t)
|
|
}
|
|
|
|
t.Header.AppendArg("typ", "JWT")
|
|
|
|
return t
|
|
}
|
|
|
|
func Parse(token string) (*Token, error) {
|
|
arr := strings.Split(token, ".")
|
|
if len(arr) != 3 {
|
|
return nil, ErrTokenMalformed
|
|
}
|
|
|
|
// decode signature
|
|
s, err := base64.RawURLEncoding.DecodeString(arr[2])
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
hp := new(bytes.Buffer)
|
|
hp.WriteString(arr[0])
|
|
hp.WriteByte(0x2E) // .
|
|
hp.WriteString(arr[1])
|
|
defer hp.Reset()
|
|
|
|
t := &Token{
|
|
Header: &Header{
|
|
Map: &Map{},
|
|
},
|
|
Payload: &Payload{
|
|
Map: &Map{},
|
|
},
|
|
body: hp.Bytes(),
|
|
sign: s,
|
|
}
|
|
|
|
// decode header
|
|
if err := t.decodeSegment(arr[0], t.Header.Map); err != nil {
|
|
return t, err
|
|
}
|
|
|
|
if val, err := t.Header.GetType(); err != nil {
|
|
return t, err
|
|
} else if val != "JWT" {
|
|
return t, ErrNotJWTType
|
|
}
|
|
|
|
// decode payload
|
|
if err := t.decodeSegment(arr[1], t.Payload.Map); err != nil {
|
|
return t, err
|
|
}
|
|
|
|
return t, nil
|
|
}
|