auth/jwt/jwt.go
2025-01-11 00:42:32 +03:00

166 lines
2.9 KiB
Go

package jwt
import (
"bytes"
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"strings"
"git.daebt.dev/auth/algo"
)
// Errors reported while reading or writing entries of the header/payload
// maps. They are sentinel values: compare them with errors.Is.
//
// NOTE(review): none of these is returned by a function in this file;
// presumably they come from the Map/Header/Payload helpers — confirm.
var (
	ErrKeyNotExist    = errors.New("key does not exist")
	ErrKeyEmpty       = errors.New("key is empty")
	ErrKeyInvalidType = errors.New("key is of invalid type")
	ErrPayloadEmpty   = errors.New("payload is empty")
)

// Errors reported while signing, parsing or verifying a token. They are
// sentinel values: compare them with errors.Is.
var (
	// ErrAlgorithmNil is returned by Token.Sign and Token.Verify when the
	// supplied algorithm is nil.
	ErrAlgorithmNil = errors.New("algorithm is nil")

	// ErrTokenMalformed is returned by Parse when the input does not consist
	// of exactly three '.'-separated segments.
	ErrTokenMalformed = errors.New("token is malformed")

	// ErrNotJWTType is returned by Parse when the header's "typ" value is
	// not "JWT".
	ErrNotJWTType = errors.New("token of not JWT type")

	// ErrAlgorithmMismatch is returned by Token.Verify when the token's "alg"
	// header differs from the algorithm the caller verifies with.
	ErrAlgorithmMismatch = errors.New("token is signed by another algorithm")
)
// Token is a JSON Web Token held as a header map, a payload map and, once
// signed or parsed, the raw signing input and signature.
//
// Tokens are normally created with New (fresh, unsigned) or Parse (decoded
// from the compact "header.payload.signature" form). NOTE(review): a zero
// Token{} has nil Header/Payload, and Sign/Verify dereference them without
// checking.
type Token struct {
	// Header is serialised as the first segment. New sets "typ" and Sign
	// sets "alg" on it.
	Header *Header
	// Payload is serialised as the second segment.
	Payload *Payload
	// body is the signing input: the base64url header and payload segments
	// joined by '.'. It is set by Sign and Parse and consumed by Verify.
	body []byte
	// sign is the raw (base64url-decoded) signature over body. It is set by
	// Sign and Parse and consumed by Verify.
	sign []byte
}
// Sign serialises and signs the token with a.
//
// It first records a.Algo() under the "alg" header key. It then encodes the
// header and payload maps as unpadded base64url JSON segments and signs the
// "header.payload" bytes with a. The signing input and the raw signature are
// kept on t (Verify uses them later).
//
// On success the compact form "header.payload.signature" is returned, with
// the signature base64url-encoded without padding. ErrAlgorithmNil is
// returned when a is nil. Encoding and signing errors are returned
// unchanged.
func (t *Token) Sign(a algo.Algorithm) (string, error) {
	if a == nil {
		return "", ErrAlgorithmNil
	}
	t.Header.AppendArg("alg", a.Algo())

	header, err := t.encodeSegment(t.Header.Map)
	if err != nil {
		return "", err
	}
	payload, err := t.encodeSegment(t.Payload.Map)
	if err != nil {
		return "", err
	}

	t.body = []byte(header + "." + payload)
	// t.sign is assigned even when signing fails, matching the original flow.
	if t.sign, err = a.Sign(t.body); err != nil {
		return "", err
	}

	signature := base64.RawURLEncoding.EncodeToString(t.sign)
	// %s formats the []byte body as text.
	return fmt.Sprintf("%s.%s", t.body, signature), nil
}
// encodeSegment marshals val to JSON and returns it as unpadded base64url
// text, the form each JWT segment takes. Marshalling errors are returned
// unchanged. The receiver is not used.
func (t *Token) encodeSegment(val any) (string, error) {
	raw, err := json.Marshal(val)
	if err == nil {
		return base64.RawURLEncoding.EncodeToString(raw), nil
	}
	return "", err
}
// decodeSegment is the inverse of encodeSegment. It decodes str as unpadded
// base64url and unmarshals the resulting JSON into val, which (per
// json.Unmarshal) must be a non-nil pointer. The first base64 or JSON error
// is returned unchanged. The receiver is not used.
func (t *Token) decodeSegment(str string, val any) error {
	raw, err := base64.RawURLEncoding.DecodeString(str)
	if err == nil {
		err = json.Unmarshal(raw, val)
	}
	return err
}
// Verify checks the token's stored signing input and signature against a.
// Sign or Parse sets these; a token from New that was never signed has
// neither.
//
// The header's "alg" value must equal a.Algo(), otherwise
// ErrAlgorithmMismatch is returned before a.Verify is called. The algorithm
// used is therefore always the one the caller chose, never one named by the
// token. It returns ErrAlgorithmNil for a nil a. Errors from reading the
// header and from a.Verify are returned unchanged.
func (t *Token) Verify(a algo.Algorithm) error {
	if a == nil {
		return ErrAlgorithmNil
	}

	alg, err := t.Header.GetAlgorithm()
	if err != nil {
		return err
	}
	if alg != a.Algo() {
		return ErrAlgorithmMismatch
	}

	return a.Verify(t.body, t.sign)
}
// New returns an unsigned token whose header and payload start as empty
// maps. Each option is then called with the token, in order. Finally the
// header entry "typ": "JWT" is added.
func New(o ...Option) *Token {
	emptyMap := func() *Map {
		return &Map{v: map[string]json.RawMessage{}}
	}

	t := &Token{
		Header:  &Header{Map: emptyMap()},
		Payload: &Payload{Map: emptyMap()},
	}
	for _, apply := range o {
		apply(t)
	}

	// "typ" is added only after all options have run. NOTE(review): whether
	// this overrides a "typ" set by an option depends on AppendArg — confirm.
	t.Header.AppendArg("typ", "JWT")
	return t
}
// Parse decodes a compact JWT of the form "header.payload.signature".
//
// Parse does NOT verify the signature. Call Token.Verify with the expected
// algorithm before trusting anything in the payload.
//
// Errors:
//   - ErrTokenMalformed when token does not have exactly three '.'-separated
//     segments; the returned *Token is nil.
//   - the base64 decoding error for the signature segment; the *Token is nil.
//   - a base64/JSON decoding error for the header or payload, an error from
//     Header.GetType, or ErrNotJWTType when the header "typ" is not "JWT".
//     In these cases the partially populated *Token is still returned.
//
// NOTE(review): RFC 7519 makes "typ" optional. Tokens from issuers that omit
// it are rejected here if GetType errors on a missing key — confirm that is
// intended.
func Parse(token string) (*Token, error) {
	// SplitN caps the result at four parts. A hostile input with many dots
	// therefore cannot force one string header per dot to be allocated
	// before it is rejected. Accept/reject is unchanged: exactly three
	// parts still means exactly two dots.
	arr := strings.SplitN(token, ".", 4)
	if len(arr) != 3 {
		return nil, ErrTokenMalformed
	}

	// Decode the signature first so an obviously broken token fails cheaply.
	s, err := base64.RawURLEncoding.DecodeString(arr[2])
	if err != nil {
		return nil, err
	}

	// body is the signing input "header.payload" exactly as it appeared in
	// token. t.body aliases this buffer's storage through Bytes(), which
	// bytes.Buffer only guarantees until the next modification. The buffer
	// must therefore not be Reset, written to or reused after this point.
	var body bytes.Buffer
	body.Grow(len(arr[0]) + 1 + len(arr[1]))
	body.WriteString(arr[0])
	body.WriteByte('.')
	body.WriteString(arr[1])

	t := &Token{
		Header: &Header{
			Map: &Map{},
		},
		Payload: &Payload{
			Map: &Map{},
		},
		body: body.Bytes(),
		sign: s,
	}

	// Decode the header and insist on "typ": "JWT".
	if err = t.decodeSegment(arr[0], t.Header.Map); err != nil {
		return t, err
	}
	typ, err := t.Header.GetType()
	if err != nil {
		return t, err
	}
	if typ != "JWT" {
		return t, ErrNotJWTType
	}

	// Decode the payload.
	if err = t.decodeSegment(arr[1], t.Payload.Map); err != nil {
		return t, err
	}
	return t, nil
}