You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

453 lines
13 KiB

package cos
import (
"crypto/hmac"
"crypto/sha1"
"encoding/json"
"fmt"
"hash"
"io/ioutil"
"net/http"
"net/url"
"sort"
"strings"
"sync"
"time"
)
const (
sha1SignAlgorithm = "sha1"
privateHeaderPrefix = "x-cos-"
defaultAuthExpire = time.Hour
)
var (
defaultCVMAuthExpire = int64(600)
defaultCVMSchema = "http"
defaultCVMMetaHost = "metadata.tencentyun.com"
defaultCVMCredURI = "latest/meta-data/cam/security-credentials"
)
// 需要校验的 Headers 列表
var NeedSignHeaders = map[string]bool{
"host": true,
"range": true,
"x-cos-acl": true,
"x-cos-grant-read": true,
"x-cos-grant-write": true,
"x-cos-grant-full-control": true,
"response-content-type": true,
"response-content-language": true,
"response-expires": true,
"response-cache-control": true,
"response-content-disposition": true,
"response-content-encoding": true,
"cache-control": true,
"content-disposition": true,
"content-encoding": true,
"content-type": true,
"content-length": true,
"content-md5": true,
"expect": true,
"expires": true,
"x-cos-content-sha1": true,
"x-cos-storage-class": true,
"if-modified-since": true,
"origin": true,
"access-control-request-method": true,
"access-control-request-headers": true,
"x-cos-object-type": true,
}
var ciParameters = map[string]bool{
"imagemogr2/": true,
"watermark/": true,
"imageview2/": true,
}
// 非线程安全,只能在进程初始化(而不是Client初始化)时做设置
func SetNeedSignHeaders(key string, val bool) {
NeedSignHeaders[key] = val
}
func safeURLEncode(s string) string {
s = encodeURIComponent(s)
s = strings.Replace(s, "!", "%21", -1)
s = strings.Replace(s, "'", "%27", -1)
s = strings.Replace(s, "(", "%28", -1)
s = strings.Replace(s, ")", "%29", -1)
s = strings.Replace(s, "*", "%2A", -1)
return s
}
type valuesSignMap map[string][]string
func (vs valuesSignMap) Add(key, value string) {
key = strings.ToLower(key)
vs[key] = append(vs[key], value)
}
func (vs valuesSignMap) Encode() string {
var keys []string
for k := range vs {
keys = append(keys, k)
}
sort.Strings(keys)
var pairs []string
for _, k := range keys {
items := vs[k]
sort.Strings(items)
for _, val := range items {
pairs = append(
pairs,
fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val)))
}
}
return strings.Join(pairs, "&")
}
// AuthTime 用于生成签名所需的 q-sign-time 和 q-key-time 相关参数
type AuthTime struct {
SignStartTime time.Time
SignEndTime time.Time
KeyStartTime time.Time
KeyEndTime time.Time
}
// NewAuthTime 生成 AuthTime 的便捷函数
//
// expire: 从现在开始多久过期.
func NewAuthTime(expire time.Duration) *AuthTime {
signStartTime := time.Now()
keyStartTime := signStartTime
signEndTime := signStartTime.Add(expire)
keyEndTime := signEndTime
return &AuthTime{
SignStartTime: signStartTime,
SignEndTime: signEndTime,
KeyStartTime: keyStartTime,
KeyEndTime: keyEndTime,
}
}
// signString return q-sign-time string
func (a *AuthTime) signString() string {
return fmt.Sprintf("%d;%d", a.SignStartTime.Unix(), a.SignEndTime.Unix())
}
// keyString return q-key-time string
func (a *AuthTime) keyString() string {
return fmt.Sprintf("%d;%d", a.KeyStartTime.Unix(), a.KeyEndTime.Unix())
}
// newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
signTime := authTime.signString()
keyTime := authTime.keyString()
signKey := calSignKey(secretKey, keyTime)
req.Header.Set("Host", req.Host)
formatHeaders := *new(string)
signedHeaderList := *new([]string)
formatHeaders, signedHeaderList = genFormatHeaders(req.Header)
formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
signature := calSignature(signKey, stringToSign)
return genAuthorization(
secretID, signTime, keyTime, signature, signedHeaderList,
signedParameterList,
)
}
// AddAuthorizationHeader 给 req 增加签名信息
func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
if secretID == "" {
return
}
auth := newAuthorization(secretID, secretKey, req,
authTime,
)
if len(sessionToken) > 0 {
req.Header.Set("x-cos-security-token", sessionToken)
}
req.Header.Set("Authorization", auth)
}
// calSignKey 计算 SignKey
func calSignKey(secretKey, keyTime string) string {
digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm)
return fmt.Sprintf("%x", digest)
}
// calStringToSign 计算 StringToSign
func calStringToSign(signAlgorithm, signTime, formatString string) string {
h := sha1.New()
h.Write([]byte(formatString))
return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, h.Sum(nil))
}
// calSignature 计算 Signature
func calSignature(signKey, stringToSign string) string {
digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
return fmt.Sprintf("%x", digest)
}
// genAuthorization 生成 Authorization
func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
return strings.Join([]string{
"q-sign-algorithm=" + sha1SignAlgorithm,
"q-ak=" + secretID,
"q-sign-time=" + signTime,
"q-key-time=" + keyTime,
"q-header-list=" + strings.Join(signedHeaderList, ";"),
"q-url-param-list=" + strings.Join(signedParameterList, ";"),
"q-signature=" + signature,
}, "&")
}
// genFormatString 生成 FormatString
func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
formatMethod := strings.ToLower(method)
formatURI := uri.Path
return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI,
formatParameters, formatHeaders,
)
}
// genFormatParameters 生成 FormatParameters 和 SignedParameterList
// instead of the url.Values{}
func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
ps := valuesSignMap{}
for key, values := range parameters {
key = strings.ToLower(key)
for _, value := range values {
if !isCIParameter(key) {
ps.Add(key, value)
signedParameterList = append(signedParameterList, key)
}
}
}
//formatParameters = strings.ToLower(ps.Encode())
formatParameters = ps.Encode()
sort.Strings(signedParameterList)
return
}
// genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
hs := valuesSignMap{}
for key, values := range headers {
key = strings.ToLower(key)
for _, value := range values {
if isSignHeader(key) {
hs.Add(key, value)
signedHeaderList = append(signedHeaderList, key)
}
}
}
formatHeaders = hs.Encode()
sort.Strings(signedHeaderList)
return
}
// HMAC 签名
func calHMACDigest(key, msg, signMethod string) []byte {
var hashFunc func() hash.Hash
switch signMethod {
case "sha1":
hashFunc = sha1.New
default:
hashFunc = sha1.New
}
h := hmac.New(hashFunc, []byte(key))
h.Write([]byte(msg))
return h.Sum(nil)
}
func isCIParameter(key string) bool {
for k, v := range ciParameters {
if strings.HasPrefix(key, k) && v {
return true
}
}
return false
}
func isSignHeader(key string) bool {
for k, v := range NeedSignHeaders {
if key == k && v {
return true
}
}
return strings.HasPrefix(key, privateHeaderPrefix)
}
// AuthorizationTransport 给请求增加 Authorization header
type AuthorizationTransport struct {
SecretID string
SecretKey string
SessionToken string
rwLocker sync.RWMutex
// 签名多久过期
Expire time.Duration
Transport http.RoundTripper
}
// SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken
func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
t.rwLocker.Lock()
defer t.rwLocker.Unlock()
t.SecretID = ak
t.SecretKey = sk
t.SessionToken = token
}
// GetCredential get the ak, sk, token
func (t *AuthorizationTransport) GetCredential() (string, string, string) {
t.rwLocker.RLock()
defer t.rwLocker.RUnlock()
return t.SecretID, t.SecretKey, t.SessionToken
}
// RoundTrip implements the RoundTripper interface.
func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req = cloneRequest(req) // per RoundTrip contract
ak, sk, token := t.GetCredential()
// 增加 Authorization header
authTime := NewAuthTime(defaultAuthExpire)
AddAuthorizationHeader(ak, sk, token, req, authTime)
resp, err := t.transport().RoundTrip(req)
return resp, err
}
func (t *AuthorizationTransport) transport() http.RoundTripper {
if t.Transport != nil {
return t.Transport
}
return http.DefaultTransport
}
type CVMSecurityCredentials struct {
TmpSecretId string `json:",omitempty"`
TmpSecretKey string `json:",omitempty"`
ExpiredTime int64 `json:",omitempty"`
Expiration string `json:",omitempty"`
Token string `json:",omitempty"`
Code string `json:",omitempty"`
}
type CVMCredentialsTransport struct {
RoleName string
Transport http.RoundTripper
secretID string
secretKey string
sessionToken string
expiredTime int64
rwLocker sync.RWMutex
}
func (t *CVMCredentialsTransport) GetRoles() ([]string, error) {
urlname := fmt.Sprintf("%s://%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI)
resp, err := http.Get(urlname)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode > 299 {
bs, _ := ioutil.ReadAll(resp.Body)
return nil, fmt.Errorf("get cvm security-credentials role failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
}
bs, err := ioutil.ReadAll(resp.Body)
if err != nil {
return nil, err
}
roles := strings.Split(strings.TrimSpace(string(bs)), "\n")
if len(roles) == 0 {
return nil, fmt.Errorf("get cvm security-credentials role failed, No valid cam role was found")
}
return roles, nil
}
// https://cloud.tencent.com/document/product/213/4934
func (t *CVMCredentialsTransport) UpdateCredential(now int64) (string, string, string, error) {
t.rwLocker.Lock()
defer t.rwLocker.Unlock()
if t.expiredTime > now+defaultCVMAuthExpire {
return t.secretID, t.secretKey, t.sessionToken, nil
}
roleName := t.RoleName
if roleName == "" {
roles, err := t.GetRoles()
if err != nil {
return t.secretID, t.secretKey, t.sessionToken, err
}
roleName = roles[0]
}
urlname := fmt.Sprintf("%s://%s/%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI, roleName)
resp, err := http.Get(urlname)
if err != nil {
return t.secretID, t.secretKey, t.sessionToken, err
}
defer resp.Body.Close()
if resp.StatusCode < 200 || resp.StatusCode > 299 {
bs, _ := ioutil.ReadAll(resp.Body)
return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
}
var cred CVMSecurityCredentials
err = json.NewDecoder(resp.Body).Decode(&cred)
if err != nil {
return t.secretID, t.secretKey, t.sessionToken, err
}
if cred.Code != "Success" {
return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, Code:%v", cred.Code)
}
t.secretID, t.secretKey, t.sessionToken, t.expiredTime = cred.TmpSecretId, cred.TmpSecretKey, cred.Token, cred.ExpiredTime
return t.secretID, t.secretKey, t.sessionToken, nil
}
func (t *CVMCredentialsTransport) GetCredential() (string, string, string, error) {
now := time.Now().Unix()
t.rwLocker.RLock()
// 提前 defaultCVMAuthExpire 获取重新获取临时密钥
if t.expiredTime <= now+defaultCVMAuthExpire {
expiredTime := t.expiredTime
t.rwLocker.RUnlock()
secretID, secretKey, secretToken, err := t.UpdateCredential(now)
// 获取临时密钥失败但密钥未过期
if err != nil && now < expiredTime {
err = nil
}
return secretID, secretKey, secretToken, err
}
defer t.rwLocker.RUnlock()
return t.secretID, t.secretKey, t.sessionToken, nil
}
func (t *CVMCredentialsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
ak, sk, token, err := t.GetCredential()
if err != nil {
return nil, err
}
req = cloneRequest(req)
// 增加 Authorization header
authTime := NewAuthTime(defaultAuthExpire)
AddAuthorizationHeader(ak, sk, token, req, authTime)
resp, err := t.transport().RoundTrip(req)
return resp, err
}
func (t *CVMCredentialsTransport) transport() http.RoundTripper {
if t.Transport != nil {
return t.Transport
}
return http.DefaultTransport
}