You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

305 lines
8.4 KiB

package cos
import (
"crypto/hmac"
"crypto/sha1"
"fmt"
"hash"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
)
const (
	// sha1SignAlgorithm is sent as q-sign-algorithm and selects the digest
	// used for every hashing/HMAC step of the signature.
	sha1SignAlgorithm = "sha1"
	// privateHeaderPrefix marks COS-specific headers; isSignHeader signs any
	// header whose lower-cased name starts with it.
	privateHeaderPrefix = "x-cos-"
	// defaultAuthExpire is the signature lifetime AuthorizationTransport uses
	// when its Expire field is zero.
	defaultAuthExpire = time.Hour
)

// needSignHeaders lists the lower-cased header names that are included in
// the signature when present on a request. Every entry maps to true.
var needSignHeaders = map[string]bool{
	// General request headers.
	"host":              true,
	"range":             true,
	"expect":            true,
	"expires":           true,
	"origin":            true,
	"if-modified-since": true,

	// Entity headers.
	"cache-control":       true,
	"content-disposition": true,
	"content-encoding":    true,
	"content-type":        true,
	"content-length":      true,
	"content-md5":         true,

	// response-* names.
	"response-content-type":        true,
	"response-content-language":    true,
	"response-expires":             true,
	"response-cache-control":       true,
	"response-content-disposition": true,
	"response-content-encoding":    true,

	// CORS preflight request headers.
	"access-control-request-method":  true,
	"access-control-request-headers": true,

	// COS-specific headers (these also match privateHeaderPrefix).
	"x-cos-acl":                true,
	"x-cos-grant-read":         true,
	"x-cos-grant-write":        true,
	"x-cos-grant-full-control": true,
	"x-cos-content-sha1":       true,
	"x-cos-storage-class":      true,
	"x-cos-object-type":        true,
}
// signEncodeReplacer percent-encodes the sub-delimiter characters ! ' ( ) *.
// None of the replacement strings contains a matched character, so a single
// pass produces the same result as replacing each character in turn.
var signEncodeReplacer = strings.NewReplacer(
	"!", "%21",
	"'", "%27",
	"(", "%28",
	")", "%29",
	"*", "%2A",
)

// safeURLEncode encodes s with encodeURIComponent and then additionally
// percent-encodes ! ' ( ) *, so none of those characters appear literally in
// the encoded output.
func safeURLEncode(s string) string {
	encoded := encodeURIComponent(s)
	return signEncodeReplacer.Replace(encoded)
}
type valuesSignMap map[string][]string
func (vs valuesSignMap) Add(key, value string) {
key = strings.ToLower(key)
vs[key] = append(vs[key], value)
}
func (vs valuesSignMap) Encode() string {
var keys []string
for k := range vs {
keys = append(keys, k)
}
sort.Strings(keys)
var pairs []string
for _, k := range keys {
items := vs[k]
sort.Strings(items)
for _, val := range items {
pairs = append(
pairs,
fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val)))
}
}
return strings.Join(pairs, "&")
}
// AuthTime holds the validity windows used when signing a request.
type AuthTime struct {
	// SignStartTime and SignEndTime bound the q-sign-time window.
	SignStartTime, SignEndTime time.Time
	// KeyStartTime and KeyEndTime bound the q-key-time window.
	KeyStartTime, KeyEndTime time.Time
}

// NewAuthTime returns an AuthTime whose sign and key windows are identical:
// both start now and end expire later.
func NewAuthTime(expire time.Duration) *AuthTime {
	now := time.Now()
	deadline := now.Add(expire)
	return &AuthTime{
		SignStartTime: now,
		SignEndTime:   deadline,
		KeyStartTime:  now,
		KeyEndTime:    deadline,
	}
}

// signString formats the sign window as "start;end" in Unix seconds, the
// value used for q-sign-time.
func (a *AuthTime) signString() string {
	start, end := a.SignStartTime.Unix(), a.SignEndTime.Unix()
	return fmt.Sprintf("%d;%d", start, end)
}

// keyString formats the key window as "start;end" in Unix seconds, the
// value used for q-key-time.
func (a *AuthTime) keyString() string {
	start, end := a.KeyStartTime.Unix(), a.KeyEndTime.Unix()
	return fmt.Sprintf("%d;%d", start, end)
}
// newAuthorization computes the Authorization header value for req.
//
// The pipeline is:
//  1. SignKey = hex(HMAC-SHA1(secretKey, q-key-time)).
//  2. Canonicalise the signable headers and the query parameters.
//  3. FormatString from method, path, parameters and headers.
//  4. StringToSign from FormatString; signature = HMAC of it with SignKey.
//  5. Join everything into the q-* fields.
//
// Only headers present in req.Header are considered. net/http normally
// carries the host in req.Host / req.URL.Host rather than in req.Header, so
// "host" is only signed if the caller put it in req.Header explicitly.
//
// NOTE(review): StringToSign is built with keyTime, not signTime. The two are
// equal for AuthTime values from NewAuthTime; confirm against the COS
// signing specification before using differing windows.
func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
	signTime := authTime.signString()
	keyTime := authTime.keyString()
	signKey := calSignKey(secretKey, keyTime)

	formatHeaders, signedHeaderList := genFormatHeaders(req.Header)
	formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
	formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)

	stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
	signature := calSignature(signKey, stringToSign)

	return genAuthorization(
		secretID, signTime, keyTime, signature,
		signedHeaderList, signedParameterList,
	)
}
// AddAuthorizationHeader signs req and stores the result in its
// Authorization header. It is a no-op when secretID is empty.
//
// A non-empty sessionToken is sent as x-cos-security-token. That header is
// set only after the signature has been computed.
func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
	if len(secretID) == 0 {
		return
	}

	authorization := newAuthorization(secretID, secretKey, req, authTime)
	if sessionToken != "" {
		req.Header.Set("x-cos-security-token", sessionToken)
	}
	req.Header.Set("Authorization", authorization)
}
// calSignKey derives SignKey: the lower-case hex HMAC of keyTime keyed by
// secretKey, using sha1SignAlgorithm.
func calSignKey(secretKey, keyTime string) string {
	return fmt.Sprintf("%x", calHMACDigest(secretKey, keyTime, sha1SignAlgorithm))
}
// calStringToSign builds StringToSign as
// "<signAlgorithm>\n<signTime>\n<lower-case hex SHA-1 of formatString>\n".
func calStringToSign(signAlgorithm, signTime, formatString string) string {
	digest := sha1.Sum([]byte(formatString))
	hexDigest := fmt.Sprintf("%x", digest[:])
	return signAlgorithm + "\n" + signTime + "\n" + hexDigest + "\n"
}
// calSignature returns the lower-case hex HMAC of stringToSign keyed by
// signKey; this becomes the q-signature field.
func calSignature(signKey, stringToSign string) string {
	mac := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
	return fmt.Sprintf("%x", mac)
}
// genAuthorization assembles the Authorization value from the signing
// results. The fields q-sign-algorithm, q-ak, q-sign-time, q-key-time,
// q-header-list, q-url-param-list and q-signature appear in that order,
// separated by "&". The header and parameter lists are joined with ";".
func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
	headerList := strings.Join(signedHeaderList, ";")
	paramList := strings.Join(signedParameterList, ";")

	fields := make([]string, 0, 7)
	fields = append(fields, "q-sign-algorithm="+sha1SignAlgorithm)
	fields = append(fields, "q-ak="+secretID)
	fields = append(fields, "q-sign-time="+signTime)
	fields = append(fields, "q-key-time="+keyTime)
	fields = append(fields, "q-header-list="+headerList)
	fields = append(fields, "q-url-param-list="+paramList)
	fields = append(fields, "q-signature="+signature)
	return strings.Join(fields, "&")
}
// genFormatString builds FormatString, the canonical request description
// that calStringToSign hashes. It contains the lower-cased method, the
// request path, the formatted parameters and the formatted headers, each
// followed by "\n".
//
// A URL with an empty, non-opaque path is signed as "/". net/http puts "/"
// on the request line for such a URL (url.URL.RequestURI), so signing ""
// would describe a request that is never actually sent.
func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
	formatMethod := strings.ToLower(method)
	formatURI := uri.Path
	if formatURI == "" && uri.Opaque == "" {
		formatURI = "/"
	}
	return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI,
		formatParameters, formatHeaders,
	)
}
// genFormatParameters canonicalises the query parameters for signing.
//
// formatParameters is the valuesSignMap encoding of parameters: keys are
// lower-cased, keys and values are URL-encoded, and both are sorted.
// signedParameterList holds the lower-cased key once per value, sorted, for
// q-url-param-list.
//
// NOTE(review): keys in signedParameterList are not passed through
// safeURLEncode, unlike the keys inside formatParameters. Parameter names
// containing reserved characters therefore differ between the two; confirm
// this against the COS signing spec.
func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
	signMap := valuesSignMap{}
	for name, vals := range parameters {
		lowerName := strings.ToLower(name)
		for _, v := range vals {
			signMap.Add(lowerName, v)
			signedParameterList = append(signedParameterList, lowerName)
		}
	}
	sort.Strings(signedParameterList)
	return signMap.Encode(), signedParameterList
}
// genFormatHeaders canonicalises the signable request headers.
//
// Only headers whose lower-cased name passes isSignHeader are used.
// formatHeaders is their valuesSignMap encoding. signedHeaderList holds the
// lower-cased name once per value, sorted, for q-header-list.
func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
	signMap := valuesSignMap{}
	for name, vals := range headers {
		lowerName := strings.ToLower(name)
		if !isSignHeader(lowerName) {
			continue
		}
		for _, v := range vals {
			signMap.Add(lowerName, v)
			signedHeaderList = append(signedHeaderList, lowerName)
		}
	}
	sort.Strings(signedHeaderList)
	return signMap.Encode(), signedHeaderList
}
// calHMACDigest returns the raw HMAC-SHA1 of msg keyed by key.
//
// signMethod is accepted for future algorithms, but SHA-1 is the only one
// implemented. "sha1" and any other value both select SHA-1.
func calHMACDigest(key, msg, signMethod string) []byte {
	var mac hash.Hash = hmac.New(sha1.New, []byte(key))
	// hash.Hash.Write never returns an error.
	mac.Write([]byte(msg))
	return mac.Sum(nil)
}
// isSignHeader reports whether the lower-cased header name key is part of
// the signature. That is the case if needSignHeaders maps it to true, or if
// it starts with privateHeaderPrefix.
func isSignHeader(key string) bool {
	// A direct map lookup: a missing key reads as false, so no scan over the
	// whole map is needed.
	if needSignHeaders[key] {
		return true
	}
	return strings.HasPrefix(key, privateHeaderPrefix)
}
// AuthorizationTransport signs every outgoing request with COS credentials
// before passing it to Transport. *AuthorizationTransport implements
// http.RoundTripper.
type AuthorizationTransport struct {
	SecretID     string
	SecretKey    string
	SessionToken string
	// rwLocker guards SecretID, SecretKey and SessionToken inside
	// SetCredential and GetCredential.
	rwLocker sync.RWMutex
	// Expire is how long each signature stays valid; zero means
	// defaultAuthExpire.
	Expire time.Duration
	// Transport sends the signed request; nil means http.DefaultTransport.
	Transport http.RoundTripper
}

// SetCredential replaces the SecretID (ak), SecretKey (sk) and SessionToken
// together under the write lock, so GetCredential never observes a partial
// update.
func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
	t.rwLocker.Lock()
	t.SecretID, t.SecretKey, t.SessionToken = ak, sk, token
	t.rwLocker.Unlock()
}

// GetCredential returns the SecretID, SecretKey and SessionToken as one
// snapshot taken under the read lock.
func (t *AuthorizationTransport) GetCredential() (string, string, string) {
	t.rwLocker.RLock()
	ak, sk, token := t.SecretID, t.SecretKey, t.SessionToken
	t.rwLocker.RUnlock()
	return ak, sk, token
}
// RoundTrip implements http.RoundTripper. It signs the request returned by
// cloneRequest with the current credentials and forwards it to the
// underlying transport. The RoundTripper contract says the caller's request
// must not be modified.
//
// The signature lifetime is t.Expire, or defaultAuthExpire when that is
// zero. The default is applied to a local value rather than written back to
// t.Expire: a RoundTripper must be safe for concurrent use, and an
// unsynchronised write to a shared field from here is a data race.
func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	req = cloneRequest(req) // per RoundTrip contract

	expire := t.Expire
	if expire == time.Duration(0) {
		expire = defaultAuthExpire
	}

	ak, sk, token := t.GetCredential()
	// Add the Authorization header (and the security token, if any).
	AddAuthorizationHeader(ak, sk, token, req, NewAuthTime(expire))
	return t.transport().RoundTrip(req)
}
// transport returns the RoundTripper that sends signed requests: t.Transport
// when it is set, otherwise http.DefaultTransport.
func (t *AuthorizationTransport) transport() http.RoundTripper {
	rt := t.Transport
	if rt == nil {
		rt = http.DefaultTransport
	}
	return rt
}