// Package jwt creates, signs, parses and verifies JSON Web Tokens.
package jwt

import (
	"bytes"
	"encoding/base64"
	"encoding/json"
	"errors"
	"fmt"
	"strings"
	"time"

	"git.daebt.dev/auth/algo"
)

var (
	// ErrKeyNotExist reports that a requested key is absent. Verify treats it
	// as "claim not set" when returned by the exp/nbf getters.
	ErrKeyNotExist = errors.New("key does not exist")

	// ErrKeyEmpty reports an empty key.
	ErrKeyEmpty = errors.New("key is empty")

	// ErrKeyInvalidType reports a key whose value has an unexpected type.
	ErrKeyInvalidType = errors.New("key is of invalid type")

	// ErrPayloadEmpty reports an empty payload.
	ErrPayloadEmpty = errors.New("payload is empty")

	// ErrAlgorithmNil is returned by Sign and Verify when given a nil algorithm.
	ErrAlgorithmNil = errors.New("algorithm is nil")

	// ErrTokenMalformed is returned by Parse when the token does not consist of
	// exactly three dot-separated segments.
	ErrTokenMalformed = errors.New("token is malformed")

	// ErrNotJWTType is returned by Parse when the "typ" header is not "JWT".
	ErrNotJWTType = errors.New("token of not JWT type")

	// ErrAlgorithmMismatch is returned by Verify when the token's "alg" header
	// differs from the algorithm it is verified with.
	ErrAlgorithmMismatch = errors.New("token is signed by another algorithm")

	// Err_1 carries a Russian message meaning "the specified owner(s) do not
	// exist". NOTE(review): not referenced in this file; the non-idiomatic name
	// is kept because it is exported.
	Err_1 = errors.New("заданного владельца(ов) не существует")

	// Err_2 carries a Russian message meaning "the specified aud does not
	// exist". NOTE(review): not referenced in this file; name kept as exported.
	Err_2 = errors.New("заданного aud не существует")

	// ErrExpired is returned by Verify when the "exp" claim is not in the future.
	ErrExpired = errors.New("token expired")

	// ErrNotActivated is returned by Verify when the "nbf" claim is in the future.
	ErrNotActivated = errors.New("token yet not activated")
)

// Token is a JWT made of a header and a payload. After Sign or Parse it also
// holds the signing input and the raw signature used by Verify.
type Token struct {
	Header  *Header
	Payload *Payload

	// body is the signing input: base64url(header) + "." + base64url(payload).
	body []byte
	// sign is the raw (base64url-decoded) signature over body.
	sign []byte
}

// Sign sets the "alg" header to a.Algo(), serializes header and payload as
// base64url-encoded JSON, signs "<header>.<payload>" with a and returns the
// compact serialization "<header>.<payload>.<signature>".
//
// It returns ErrAlgorithmNil if a is nil. The signing input and signature are
// stored on t, so a subsequent Verify on the same value checks them.
func (t *Token) Sign(a algo.Algorithm) (string, error) {
	if a == nil {
		return "", ErrAlgorithmNil
	}

	t.Header.AppendArg("alg", a.Algo())

	h, err := t.encodeSegment(t.Header.Map)
	if err != nil {
		return "", err
	}
	p, err := t.encodeSegment(t.Payload.Map)
	if err != nil {
		return "", err
	}

	t.body = []byte(h + "." + p)
	t.sign, err = a.Sign(t.body)
	if err != nil {
		return "", err
	}

	return fmt.Sprintf("%s.%s", t.body, base64.RawURLEncoding.EncodeToString(t.sign)), nil
}

// encodeSegment marshals val to JSON and returns it base64url-encoded without
// padding.
func (t *Token) encodeSegment(val any) (string, error) {
	buf, err := json.Marshal(val)
	if err != nil {
		return "", err
	}
	return base64.RawURLEncoding.EncodeToString(buf), nil
}

// decodeSegment base64url-decodes str (unpadded) and unmarshals the JSON
// result into val.
func (t *Token) decodeSegment(str string, val any) error {
	buf, err := base64.RawURLEncoding.DecodeString(str)
	if err != nil {
		return err
	}
	return json.Unmarshal(buf, val)
}

// Verify checks t against the algorithm a and returns the first failure:
//
//   - ErrAlgorithmNil if a is nil;
//   - ErrAlgorithmMismatch if the "alg" header differs from a.Algo();
//   - the error from a.Verify if the signature over the signing input is bad;
//   - ErrExpired if an "exp" claim is present and now is not before it;
//   - ErrNotActivated if an "nbf" claim is present and now is before it;
//   - the first non-nil error returned by the options o, run in order.
//
// The signature is checked before any claim or option so that nothing from an
// unauthenticated token is evaluated or reported on.
func (t *Token) Verify(a algo.Algorithm, o ...VerifyOption) error {
	if a == nil {
		return ErrAlgorithmNil
	}

	alg, err := t.Header.GetAlgorithm()
	if err != nil {
		return err
	}
	if alg != a.Algo() {
		return ErrAlgorithmMismatch
	}

	if err := a.Verify(t.body, t.sign); err != nil {
		return err
	}

	now := time.Now()

	switch exp, err := t.Payload.GetExpirationTime(); {
	case errors.Is(err, ErrKeyNotExist):
		// No "exp" claim: the token does not expire.
	case err != nil:
		return err
	case !now.Before(exp):
		// RFC 7519 §4.1.4: the current time MUST be before "exp".
		return ErrExpired
	}

	switch nbf, err := t.Payload.GetNotBefore(); {
	case errors.Is(err, ErrKeyNotExist):
		// No "nbf" claim: the token is active immediately.
	case err != nil:
		return err
	case now.Before(nbf):
		// RFC 7519 §4.1.5: the current time MUST be at or after "nbf".
		return ErrNotActivated
	}

	for _, f := range o {
		if err := f(t); err != nil {
			return err
		}
	}
	return nil
}

// New creates a token with empty header and payload maps, applies the options
// o in order and then sets the "typ" header to "JWT".
func New(o ...Option) *Token {
	t := &Token{
		Header: &Header{
			Map: &Map{
				v: map[string]json.RawMessage{},
			},
		},
		Payload: &Payload{
			Map: &Map{
				v: map[string]json.RawMessage{},
			},
		},
	}
	for _, f := range o {
		f(t)
	}
	t.Header.AppendArg("typ", "JWT")
	return t
}

// Parse decodes a compact-serialized token "<header>.<payload>.<signature>".
// It does NOT verify the signature or any claim; call Verify for that.
//
// Parse returns (nil, err) if the token does not have three segments
// (ErrTokenMalformed) or the signature is not valid unpadded base64url. For
// later failures (header decoding, a missing or non-"JWT" "typ" header, which
// yields ErrNotJWTType, or payload decoding) it returns the partially populated
// token together with the error.
func Parse(token string) (*Token, error) {
	parts := strings.Split(token, ".")
	if len(parts) != 3 {
		return nil, ErrTokenMalformed
	}

	sign, err := base64.RawURLEncoding.DecodeString(parts[2])
	if err != nil {
		return nil, err
	}

	// The signing input is the header and payload exactly as received. The
	// buffer is local and never reset, so the slice taken from it stays valid.
	var input bytes.Buffer
	input.Grow(len(parts[0]) + 1 + len(parts[1]))
	input.WriteString(parts[0])
	input.WriteByte('.')
	input.WriteString(parts[1])

	// Initialize the maps the same way New does, so a token returned alongside
	// an error never carries a nil map.
	t := &Token{
		Header: &Header{
			Map: &Map{
				v: map[string]json.RawMessage{},
			},
		},
		Payload: &Payload{
			Map: &Map{
				v: map[string]json.RawMessage{},
			},
		},
		body: input.Bytes(),
		sign: sign,
	}

	if err := t.decodeSegment(parts[0], t.Header.Map); err != nil {
		return t, err
	}

	typ, err := t.Header.GetType()
	if err != nil {
		return t, err
	}
	if typ != "JWT" {
		return t, ErrNotJWTType
	}

	if err := t.decodeSegment(parts[1], t.Payload.Map); err != nil {
		return t, err
	}
	return t, nil
}