|
|
package cos
import ( "crypto/hmac" "crypto/sha1" "encoding/json" "fmt" "hash" "io/ioutil" "net/http" "net/url" "sort" "strings" "sync" "time" )
// Signing constants shared by the authorization code below.
const (
	// sha1SignAlgorithm is emitted as q-sign-algorithm and is the only hash
	// the signer implements.
	sha1SignAlgorithm = "sha1"
	// privateHeaderPrefix marks COS-specific headers; headers carrying it are
	// always signed.
	privateHeaderPrefix = "x-cos-"
	// defaultAuthExpire is the signature lifetime the transports pass to
	// NewAuthTime.
	defaultAuthExpire = time.Hour
)
// Settings for the CVM instance-metadata service used by
// CVMCredentialsTransport. They are variables rather than constants,
// presumably so they can be overridden (e.g. in tests) — TODO confirm.
var (
	// defaultCVMAuthExpire is a margin in Unix seconds: cached temporary
	// credentials are refreshed once they are within this window of expiring.
	defaultCVMAuthExpire = int64(600)
	defaultCVMSchema     = "http"
	defaultCVMMetaHost   = "metadata.tencentyun.com"
	defaultCVMCredURI    = "latest/meta-data/cam/security-credentials"
)
// NeedSignHeaders lists the request headers, by lower-case name, that are
// included in the signature when present on a request. An entry mapped to
// false is ignored. Headers carrying the x-cos- prefix are signed regardless
// of this map.
var NeedSignHeaders = map[string]bool{
	"host":                           true,
	"range":                          true,
	"x-cos-acl":                      true,
	"x-cos-grant-read":               true,
	"x-cos-grant-write":              true,
	"x-cos-grant-full-control":       true,
	"response-content-type":          true,
	"response-content-language":      true,
	"response-expires":               true,
	"response-cache-control":         true,
	"response-content-disposition":   true,
	"response-content-encoding":      true,
	"cache-control":                  true,
	"content-disposition":            true,
	"content-encoding":               true,
	"content-type":                   true,
	"content-length":                 true,
	"content-md5":                    true,
	"expect":                         true,
	"expires":                        true,
	"x-cos-content-sha1":             true,
	"x-cos-storage-class":            true,
	"if-modified-since":              true,
	"origin":                         true,
	"access-control-request-method":  true,
	"access-control-request-headers": true,
	"x-cos-object-type":              true,
}

// ciParameters lists lower-case query-parameter key prefixes that are left
// out of the signature. An entry mapped to false is ignored. (They look like
// image-processing directives — TODO confirm.)
var ciParameters = map[string]bool{
	"imagemogr2/": true,
	"watermark/":  true,
	"imageview2/": true,
}

// SetNeedSignHeaders enables (val == true) or disables (val == false)
// signing of the header named key.
//
// The name is lower-cased before it is stored, because signing looks headers
// up by their lower-cased names. A mixed-case key such as "X-Foo" would
// otherwise never match.
//
// Not thread-safe: it writes the package-level NeedSignHeaders map without
// locking. Call it only during process initialisation (not per Client), before
// any request is signed.
func SetNeedSignHeaders(key string, val bool) {
	NeedSignHeaders[strings.ToLower(key)] = val
}
func safeURLEncode(s string) string { s = encodeURIComponent(s) s = strings.Replace(s, "!", "%21", -1) s = strings.Replace(s, "'", "%27", -1) s = strings.Replace(s, "(", "%28", -1) s = strings.Replace(s, ")", "%29", -1) s = strings.Replace(s, "*", "%2A", -1) return s }
type valuesSignMap map[string][]string
func (vs valuesSignMap) Add(key, value string) { key = strings.ToLower(key) vs[key] = append(vs[key], value) }
func (vs valuesSignMap) Encode() string { var keys []string for k := range vs { keys = append(keys, k) } sort.Strings(keys)
var pairs []string for _, k := range keys { items := vs[k] sort.Strings(items) for _, val := range items { pairs = append( pairs, fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val))) } } return strings.Join(pairs, "&") }
// AuthTime holds the validity windows used when signing.
//
// The Sign range is rendered as q-sign-time and the Key range as q-key-time.
// Both are written as "start;end" in Unix seconds.
type AuthTime struct {
	SignStartTime time.Time
	SignEndTime   time.Time
	KeyStartTime  time.Time
	KeyEndTime    time.Time
}

// NewAuthTime returns an AuthTime whose sign and key windows both start now
// and end after expire.
//
// expire: how long from now until the signature expires.
func NewAuthTime(expire time.Duration) *AuthTime {
	start := time.Now()
	end := start.Add(expire)
	return &AuthTime{
		SignStartTime: start,
		SignEndTime:   end,
		KeyStartTime:  start,
		KeyEndTime:    end,
	}
}

// authTimeRange formats a [from, to] window as "from;to" in Unix seconds.
func authTimeRange(from, to time.Time) string {
	return fmt.Sprintf("%d;%d", from.Unix(), to.Unix())
}

// signString returns the q-sign-time value.
func (a *AuthTime) signString() string {
	return authTimeRange(a.SignStartTime, a.SignEndTime)
}

// keyString returns the q-key-time value.
func (a *AuthTime) keyString() string {
	return authTimeRange(a.KeyStartTime, a.KeyEndTime)
}
// newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string { signTime := authTime.signString() keyTime := authTime.keyString() signKey := calSignKey(secretKey, keyTime)
req.Header.Set("Host", req.Host) formatHeaders := *new(string) signedHeaderList := *new([]string) formatHeaders, signedHeaderList = genFormatHeaders(req.Header) formatParameters, signedParameterList := genFormatParameters(req.URL.Query()) formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString) signature := calSignature(signKey, stringToSign)
return genAuthorization( secretID, signTime, keyTime, signature, signedHeaderList, signedParameterList, ) }
// AddAuthorizationHeader 给 req 增加签名信息
func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) { if secretID == "" { return }
auth := newAuthorization(secretID, secretKey, req, authTime, ) if len(sessionToken) > 0 { req.Header.Set("x-cos-security-token", sessionToken) } req.Header.Set("Authorization", auth) }
// calSignKey 计算 SignKey
func calSignKey(secretKey, keyTime string) string { digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm) return fmt.Sprintf("%x", digest) }
// calStringToSign builds StringToSign. It is three newline-terminated lines:
// the algorithm name, the time window, and the lower-case hex SHA-1 of
// formatString.
func calStringToSign(signAlgorithm, signTime, formatString string) string {
	digest := sha1.Sum([]byte(formatString))
	return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, digest[:])
}
// calSignature 计算 Signature
func calSignature(signKey, stringToSign string) string { digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm) return fmt.Sprintf("%x", digest) }
// genAuthorization assembles the Authorization header value.
//
// The fields are q-sign-algorithm, q-ak, q-sign-time, q-key-time,
// q-header-list, q-url-param-list and q-signature, in that order, joined by
// "&". The header and parameter lists are joined by ";".
func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
	headerList := strings.Join(signedHeaderList, ";")
	paramList := strings.Join(signedParameterList, ";")
	fields := [...]string{
		"q-sign-algorithm=" + sha1SignAlgorithm,
		"q-ak=" + secretID,
		"q-sign-time=" + signTime,
		"q-key-time=" + keyTime,
		"q-header-list=" + headerList,
		"q-url-param-list=" + paramList,
		"q-signature=" + signature,
	}
	return strings.Join(fields[:], "&")
}
// genFormatString builds FormatString. It consists of the lower-cased HTTP
// method, the (decoded) URL path, formatParameters and formatHeaders, each
// terminated by "\n".
func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
	// Every component, including the last, ends with a newline.
	const layout = "%s\n%s\n%s\n%s\n"
	lowerMethod := strings.ToLower(method)
	return fmt.Sprintf(layout, lowerMethod, uri.Path, formatParameters, formatHeaders)
}
// genFormatParameters 生成 FormatParameters 和 SignedParameterList
// instead of the url.Values{}
func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) { ps := valuesSignMap{} for key, values := range parameters { key = strings.ToLower(key) for _, value := range values { if !isCIParameter(key) { ps.Add(key, value) signedParameterList = append(signedParameterList, key) } } } //formatParameters = strings.ToLower(ps.Encode())
formatParameters = ps.Encode() sort.Strings(signedParameterList) return }
// genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) { hs := valuesSignMap{} for key, values := range headers { key = strings.ToLower(key) for _, value := range values { if isSignHeader(key) { hs.Add(key, value) signedHeaderList = append(signedHeaderList, key) } } } formatHeaders = hs.Encode() sort.Strings(signedHeaderList) return }
// calHMACDigest returns the raw HMAC of msg keyed by key.
//
// signMethod selects the hash function. Only SHA-1 is implemented: "sha1" and
// any unrecognised value both use it.
func calHMACDigest(key, msg, signMethod string) []byte {
	var newHash func() hash.Hash = sha1.New
	mac := hmac.New(newHash, []byte(key))
	mac.Write([]byte(msg))
	return mac.Sum(nil)
}
func isCIParameter(key string) bool { for k, v := range ciParameters { if strings.HasPrefix(key, k) && v { return true } } return false }
func isSignHeader(key string) bool { for k, v := range NeedSignHeaders { if key == k && v { return true } } return strings.HasPrefix(key, privateHeaderPrefix) }
// AuthorizationTransport is an http.RoundTripper that adds a COS
// Authorization header to every request before passing it on.
//
// The credentials may be replaced at runtime with SetCredential and are read
// with GetCredential; both use rwLocker. Because the struct holds a lock, it
// must not be copied after first use.
type AuthorizationTransport struct {
	SecretID     string
	SecretKey    string
	SessionToken string
	rwLocker     sync.RWMutex
	// Expire is how long until the generated signature expires.
	Expire time.Duration

	// Transport sends the signed request.
	Transport http.RoundTripper
}

// SetCredential replaces the SecretID (ak), SecretKey (sk) and session token
// under the write lock, so readers never observe a mix of old and new values.
func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
	t.rwLocker.Lock()
	t.SecretID, t.SecretKey, t.SessionToken = ak, sk, token
	t.rwLocker.Unlock()
}

// GetCredential returns a consistent snapshot of (ak, sk, token) taken under
// the read lock.
func (t *AuthorizationTransport) GetCredential() (string, string, string) {
	t.rwLocker.RLock()
	ak, sk, token := t.SecretID, t.SecretKey, t.SessionToken
	t.rwLocker.RUnlock()
	return ak, sk, token
}
// RoundTrip implements the RoundTripper interface.
func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) { req = cloneRequest(req) // per RoundTrip contract
ak, sk, token := t.GetCredential() // 增加 Authorization header
authTime := NewAuthTime(defaultAuthExpire) AddAuthorizationHeader(ak, sk, token, req, authTime)
resp, err := t.transport().RoundTrip(req) return resp, err }
func (t *AuthorizationTransport) transport() http.RoundTripper { if t.Transport != nil { return t.Transport } return http.DefaultTransport }
// CVMSecurityCredentials is the JSON body returned by the CVM metadata
// service for a CAM role's temporary credentials.
//
// The struct tags set no names, so the JSON keys are the Go field names
// (TmpSecretId, TmpSecretKey, ExpiredTime, Expiration, Token, Code). Renaming a
// field changes the wire mapping, which is why TmpSecretId keeps its
// non-idiomatic spelling.
//
// UpdateCredential treats Code == "Success" as success. It compares
// ExpiredTime with time.Now().Unix(), i.e. as Unix seconds.
type CVMSecurityCredentials struct {
	TmpSecretId  string `json:",omitempty"`
	TmpSecretKey string `json:",omitempty"`
	ExpiredTime  int64  `json:",omitempty"`
	Expiration   string `json:",omitempty"`
	Token        string `json:",omitempty"`
	Code         string `json:",omitempty"`
}
// CVMCredentialsTransport is an http.RoundTripper that signs requests with
// temporary credentials fetched from the CVM instance metadata service. It
// caches the credentials until they are close to expiring.
//
// Fields:
//   - RoleName selects the CAM role. When empty, the first role reported by
//     the metadata service is used.
//   - Transport sends the signed request (http.DefaultTransport when nil).
//   - The unexported fields hold the cached credentials and their expiry
//     (Unix seconds) and are guarded by rwLocker.
//
// Because the struct holds a lock, it must not be copied after first use.
type CVMCredentialsTransport struct {
	RoleName     string
	Transport    http.RoundTripper
	secretID     string
	secretKey    string
	sessionToken string
	expiredTime  int64
	rwLocker     sync.RWMutex
}
func (t *CVMCredentialsTransport) GetRoles() ([]string, error) { urlname := fmt.Sprintf("%s://%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI) resp, err := http.Get(urlname) if err != nil { return nil, err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode > 299 { bs, _ := ioutil.ReadAll(resp.Body) return nil, fmt.Errorf("get cvm security-credentials role failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs)) } bs, err := ioutil.ReadAll(resp.Body) if err != nil { return nil, err } roles := strings.Split(strings.TrimSpace(string(bs)), "\n") if len(roles) == 0 { return nil, fmt.Errorf("get cvm security-credentials role failed, No valid cam role was found") } return roles, nil }
// https://cloud.tencent.com/document/product/213/4934
func (t *CVMCredentialsTransport) UpdateCredential(now int64) (string, string, string, error) { t.rwLocker.Lock() defer t.rwLocker.Unlock() if t.expiredTime > now+defaultCVMAuthExpire { return t.secretID, t.secretKey, t.sessionToken, nil } roleName := t.RoleName if roleName == "" { roles, err := t.GetRoles() if err != nil { return t.secretID, t.secretKey, t.sessionToken, err } roleName = roles[0] } urlname := fmt.Sprintf("%s://%s/%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI, roleName) resp, err := http.Get(urlname) if err != nil { return t.secretID, t.secretKey, t.sessionToken, err } defer resp.Body.Close() if resp.StatusCode < 200 || resp.StatusCode > 299 { bs, _ := ioutil.ReadAll(resp.Body) return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs)) } var cred CVMSecurityCredentials err = json.NewDecoder(resp.Body).Decode(&cred) if err != nil { return t.secretID, t.secretKey, t.sessionToken, err } if cred.Code != "Success" { return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, Code:%v", cred.Code) } t.secretID, t.secretKey, t.sessionToken, t.expiredTime = cred.TmpSecretId, cred.TmpSecretKey, cred.Token, cred.ExpiredTime return t.secretID, t.secretKey, t.sessionToken, nil }
func (t *CVMCredentialsTransport) GetCredential() (string, string, string, error) { now := time.Now().Unix() t.rwLocker.RLock() // 提前 defaultCVMAuthExpire 获取重新获取临时密钥
if t.expiredTime <= now+defaultCVMAuthExpire { expiredTime := t.expiredTime t.rwLocker.RUnlock() secretID, secretKey, secretToken, err := t.UpdateCredential(now) // 获取临时密钥失败但密钥未过期
if err != nil && now < expiredTime { err = nil } return secretID, secretKey, secretToken, err } defer t.rwLocker.RUnlock() return t.secretID, t.secretKey, t.sessionToken, nil }
func (t *CVMCredentialsTransport) RoundTrip(req *http.Request) (*http.Response, error) { ak, sk, token, err := t.GetCredential() if err != nil { return nil, err } req = cloneRequest(req) // 增加 Authorization header
authTime := NewAuthTime(defaultAuthExpire) AddAuthorizationHeader(ak, sk, token, req, authTime)
resp, err := t.transport().RoundTrip(req) return resp, err }
func (t *CVMCredentialsTransport) transport() http.RoundTripper { if t.Transport != nil { return t.Transport } return http.DefaultTransport }
|