You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

453 lines
13 KiB

4 years ago
  1. package cos
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha1"
  5. "encoding/json"
  6. "fmt"
  7. "hash"
  8. "io/ioutil"
  9. "net/http"
  10. "net/url"
  11. "sort"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. const (
  17. sha1SignAlgorithm = "sha1"
  18. privateHeaderPrefix = "x-cos-"
  19. defaultAuthExpire = time.Hour
  20. )
  21. var (
  22. defaultCVMAuthExpire = int64(600)
  23. defaultCVMSchema = "http"
  24. defaultCVMMetaHost = "metadata.tencentyun.com"
  25. defaultCVMCredURI = "latest/meta-data/cam/security-credentials"
  26. )
  27. // 需要校验的 Headers 列表
  28. var NeedSignHeaders = map[string]bool{
  29. "host": true,
  30. "range": true,
  31. "x-cos-acl": true,
  32. "x-cos-grant-read": true,
  33. "x-cos-grant-write": true,
  34. "x-cos-grant-full-control": true,
  35. "response-content-type": true,
  36. "response-content-language": true,
  37. "response-expires": true,
  38. "response-cache-control": true,
  39. "response-content-disposition": true,
  40. "response-content-encoding": true,
  41. "cache-control": true,
  42. "content-disposition": true,
  43. "content-encoding": true,
  44. "content-type": true,
  45. "content-length": true,
  46. "content-md5": true,
  47. "expect": true,
  48. "expires": true,
  49. "x-cos-content-sha1": true,
  50. "x-cos-storage-class": true,
  51. "if-modified-since": true,
  52. "origin": true,
  53. "access-control-request-method": true,
  54. "access-control-request-headers": true,
  55. "x-cos-object-type": true,
  56. }
  57. var ciParameters = map[string]bool{
  58. "imagemogr2/": true,
  59. "watermark/": true,
  60. "imageview2/": true,
  61. }
  62. // 非线程安全,只能在进程初始化(而不是Client初始化)时做设置
  63. func SetNeedSignHeaders(key string, val bool) {
  64. NeedSignHeaders[key] = val
  65. }
  66. func safeURLEncode(s string) string {
  67. s = encodeURIComponent(s)
  68. s = strings.Replace(s, "!", "%21", -1)
  69. s = strings.Replace(s, "'", "%27", -1)
  70. s = strings.Replace(s, "(", "%28", -1)
  71. s = strings.Replace(s, ")", "%29", -1)
  72. s = strings.Replace(s, "*", "%2A", -1)
  73. return s
  74. }
  75. type valuesSignMap map[string][]string
  76. func (vs valuesSignMap) Add(key, value string) {
  77. key = strings.ToLower(key)
  78. vs[key] = append(vs[key], value)
  79. }
  80. func (vs valuesSignMap) Encode() string {
  81. var keys []string
  82. for k := range vs {
  83. keys = append(keys, k)
  84. }
  85. sort.Strings(keys)
  86. var pairs []string
  87. for _, k := range keys {
  88. items := vs[k]
  89. sort.Strings(items)
  90. for _, val := range items {
  91. pairs = append(
  92. pairs,
  93. fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val)))
  94. }
  95. }
  96. return strings.Join(pairs, "&")
  97. }
  98. // AuthTime 用于生成签名所需的 q-sign-time 和 q-key-time 相关参数
  99. type AuthTime struct {
  100. SignStartTime time.Time
  101. SignEndTime time.Time
  102. KeyStartTime time.Time
  103. KeyEndTime time.Time
  104. }
  105. // NewAuthTime 生成 AuthTime 的便捷函数
  106. //
  107. // expire: 从现在开始多久过期.
  108. func NewAuthTime(expire time.Duration) *AuthTime {
  109. signStartTime := time.Now()
  110. keyStartTime := signStartTime
  111. signEndTime := signStartTime.Add(expire)
  112. keyEndTime := signEndTime
  113. return &AuthTime{
  114. SignStartTime: signStartTime,
  115. SignEndTime: signEndTime,
  116. KeyStartTime: keyStartTime,
  117. KeyEndTime: keyEndTime,
  118. }
  119. }
  120. // signString return q-sign-time string
  121. func (a *AuthTime) signString() string {
  122. return fmt.Sprintf("%d;%d", a.SignStartTime.Unix(), a.SignEndTime.Unix())
  123. }
  124. // keyString return q-key-time string
  125. func (a *AuthTime) keyString() string {
  126. return fmt.Sprintf("%d;%d", a.KeyStartTime.Unix(), a.KeyEndTime.Unix())
  127. }
  128. // newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
  129. func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
  130. signTime := authTime.signString()
  131. keyTime := authTime.keyString()
  132. signKey := calSignKey(secretKey, keyTime)
  133. req.Header.Set("Host", req.Host)
  134. formatHeaders := *new(string)
  135. signedHeaderList := *new([]string)
  136. formatHeaders, signedHeaderList = genFormatHeaders(req.Header)
  137. formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
  138. formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
  139. stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
  140. signature := calSignature(signKey, stringToSign)
  141. return genAuthorization(
  142. secretID, signTime, keyTime, signature, signedHeaderList,
  143. signedParameterList,
  144. )
  145. }
  146. // AddAuthorizationHeader 给 req 增加签名信息
  147. func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
  148. if secretID == "" {
  149. return
  150. }
  151. auth := newAuthorization(secretID, secretKey, req,
  152. authTime,
  153. )
  154. if len(sessionToken) > 0 {
  155. req.Header.Set("x-cos-security-token", sessionToken)
  156. }
  157. req.Header.Set("Authorization", auth)
  158. }
  159. // calSignKey 计算 SignKey
  160. func calSignKey(secretKey, keyTime string) string {
  161. digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm)
  162. return fmt.Sprintf("%x", digest)
  163. }
  164. // calStringToSign 计算 StringToSign
  165. func calStringToSign(signAlgorithm, signTime, formatString string) string {
  166. h := sha1.New()
  167. h.Write([]byte(formatString))
  168. return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, h.Sum(nil))
  169. }
  170. // calSignature 计算 Signature
  171. func calSignature(signKey, stringToSign string) string {
  172. digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
  173. return fmt.Sprintf("%x", digest)
  174. }
  175. // genAuthorization 生成 Authorization
  176. func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
  177. return strings.Join([]string{
  178. "q-sign-algorithm=" + sha1SignAlgorithm,
  179. "q-ak=" + secretID,
  180. "q-sign-time=" + signTime,
  181. "q-key-time=" + keyTime,
  182. "q-header-list=" + strings.Join(signedHeaderList, ";"),
  183. "q-url-param-list=" + strings.Join(signedParameterList, ";"),
  184. "q-signature=" + signature,
  185. }, "&")
  186. }
  187. // genFormatString 生成 FormatString
  188. func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
  189. formatMethod := strings.ToLower(method)
  190. formatURI := uri.Path
  191. return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI,
  192. formatParameters, formatHeaders,
  193. )
  194. }
  195. // genFormatParameters 生成 FormatParameters 和 SignedParameterList
  196. // instead of the url.Values{}
  197. func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
  198. ps := valuesSignMap{}
  199. for key, values := range parameters {
  200. key = strings.ToLower(key)
  201. for _, value := range values {
  202. if !isCIParameter(key) {
  203. ps.Add(key, value)
  204. signedParameterList = append(signedParameterList, key)
  205. }
  206. }
  207. }
  208. //formatParameters = strings.ToLower(ps.Encode())
  209. formatParameters = ps.Encode()
  210. sort.Strings(signedParameterList)
  211. return
  212. }
  213. // genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
  214. func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
  215. hs := valuesSignMap{}
  216. for key, values := range headers {
  217. key = strings.ToLower(key)
  218. for _, value := range values {
  219. if isSignHeader(key) {
  220. hs.Add(key, value)
  221. signedHeaderList = append(signedHeaderList, key)
  222. }
  223. }
  224. }
  225. formatHeaders = hs.Encode()
  226. sort.Strings(signedHeaderList)
  227. return
  228. }
  229. // HMAC 签名
  230. func calHMACDigest(key, msg, signMethod string) []byte {
  231. var hashFunc func() hash.Hash
  232. switch signMethod {
  233. case "sha1":
  234. hashFunc = sha1.New
  235. default:
  236. hashFunc = sha1.New
  237. }
  238. h := hmac.New(hashFunc, []byte(key))
  239. h.Write([]byte(msg))
  240. return h.Sum(nil)
  241. }
  242. func isCIParameter(key string) bool {
  243. for k, v := range ciParameters {
  244. if strings.HasPrefix(key, k) && v {
  245. return true
  246. }
  247. }
  248. return false
  249. }
  250. func isSignHeader(key string) bool {
  251. for k, v := range NeedSignHeaders {
  252. if key == k && v {
  253. return true
  254. }
  255. }
  256. return strings.HasPrefix(key, privateHeaderPrefix)
  257. }
  258. // AuthorizationTransport 给请求增加 Authorization header
  259. type AuthorizationTransport struct {
  260. SecretID string
  261. SecretKey string
  262. SessionToken string
  263. rwLocker sync.RWMutex
  264. // 签名多久过期
  265. Expire time.Duration
  266. Transport http.RoundTripper
  267. }
  268. // SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken
  269. func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
  270. t.rwLocker.Lock()
  271. defer t.rwLocker.Unlock()
  272. t.SecretID = ak
  273. t.SecretKey = sk
  274. t.SessionToken = token
  275. }
  276. // GetCredential get the ak, sk, token
  277. func (t *AuthorizationTransport) GetCredential() (string, string, string) {
  278. t.rwLocker.RLock()
  279. defer t.rwLocker.RUnlock()
  280. return t.SecretID, t.SecretKey, t.SessionToken
  281. }
  282. // RoundTrip implements the RoundTripper interface.
  283. func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  284. req = cloneRequest(req) // per RoundTrip contract
  285. ak, sk, token := t.GetCredential()
  286. // 增加 Authorization header
  287. authTime := NewAuthTime(defaultAuthExpire)
  288. AddAuthorizationHeader(ak, sk, token, req, authTime)
  289. resp, err := t.transport().RoundTrip(req)
  290. return resp, err
  291. }
  292. func (t *AuthorizationTransport) transport() http.RoundTripper {
  293. if t.Transport != nil {
  294. return t.Transport
  295. }
  296. return http.DefaultTransport
  297. }
  298. type CVMSecurityCredentials struct {
  299. TmpSecretId string `json:",omitempty"`
  300. TmpSecretKey string `json:",omitempty"`
  301. ExpiredTime int64 `json:",omitempty"`
  302. Expiration string `json:",omitempty"`
  303. Token string `json:",omitempty"`
  304. Code string `json:",omitempty"`
  305. }
  306. type CVMCredentialsTransport struct {
  307. RoleName string
  308. Transport http.RoundTripper
  309. secretID string
  310. secretKey string
  311. sessionToken string
  312. expiredTime int64
  313. rwLocker sync.RWMutex
  314. }
  315. func (t *CVMCredentialsTransport) GetRoles() ([]string, error) {
  316. urlname := fmt.Sprintf("%s://%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI)
  317. resp, err := http.Get(urlname)
  318. if err != nil {
  319. return nil, err
  320. }
  321. defer resp.Body.Close()
  322. if resp.StatusCode < 200 || resp.StatusCode > 299 {
  323. bs, _ := ioutil.ReadAll(resp.Body)
  324. return nil, fmt.Errorf("get cvm security-credentials role failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
  325. }
  326. bs, err := ioutil.ReadAll(resp.Body)
  327. if err != nil {
  328. return nil, err
  329. }
  330. roles := strings.Split(strings.TrimSpace(string(bs)), "\n")
  331. if len(roles) == 0 {
  332. return nil, fmt.Errorf("get cvm security-credentials role failed, No valid cam role was found")
  333. }
  334. return roles, nil
  335. }
  336. // https://cloud.tencent.com/document/product/213/4934
  337. func (t *CVMCredentialsTransport) UpdateCredential(now int64) (string, string, string, error) {
  338. t.rwLocker.Lock()
  339. defer t.rwLocker.Unlock()
  340. if t.expiredTime > now+defaultCVMAuthExpire {
  341. return t.secretID, t.secretKey, t.sessionToken, nil
  342. }
  343. roleName := t.RoleName
  344. if roleName == "" {
  345. roles, err := t.GetRoles()
  346. if err != nil {
  347. return t.secretID, t.secretKey, t.sessionToken, err
  348. }
  349. roleName = roles[0]
  350. }
  351. urlname := fmt.Sprintf("%s://%s/%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI, roleName)
  352. resp, err := http.Get(urlname)
  353. if err != nil {
  354. return t.secretID, t.secretKey, t.sessionToken, err
  355. }
  356. defer resp.Body.Close()
  357. if resp.StatusCode < 200 || resp.StatusCode > 299 {
  358. bs, _ := ioutil.ReadAll(resp.Body)
  359. return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
  360. }
  361. var cred CVMSecurityCredentials
  362. err = json.NewDecoder(resp.Body).Decode(&cred)
  363. if err != nil {
  364. return t.secretID, t.secretKey, t.sessionToken, err
  365. }
  366. if cred.Code != "Success" {
  367. return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, Code:%v", cred.Code)
  368. }
  369. t.secretID, t.secretKey, t.sessionToken, t.expiredTime = cred.TmpSecretId, cred.TmpSecretKey, cred.Token, cred.ExpiredTime
  370. return t.secretID, t.secretKey, t.sessionToken, nil
  371. }
  372. func (t *CVMCredentialsTransport) GetCredential() (string, string, string, error) {
  373. now := time.Now().Unix()
  374. t.rwLocker.RLock()
  375. // 提前 defaultCVMAuthExpire 获取重新获取临时密钥
  376. if t.expiredTime <= now+defaultCVMAuthExpire {
  377. expiredTime := t.expiredTime
  378. t.rwLocker.RUnlock()
  379. secretID, secretKey, secretToken, err := t.UpdateCredential(now)
  380. // 获取临时密钥失败但密钥未过期
  381. if err != nil && now < expiredTime {
  382. err = nil
  383. }
  384. return secretID, secretKey, secretToken, err
  385. }
  386. defer t.rwLocker.RUnlock()
  387. return t.secretID, t.secretKey, t.sessionToken, nil
  388. }
  389. func (t *CVMCredentialsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  390. ak, sk, token, err := t.GetCredential()
  391. if err != nil {
  392. return nil, err
  393. }
  394. req = cloneRequest(req)
  395. // 增加 Authorization header
  396. authTime := NewAuthTime(defaultAuthExpire)
  397. AddAuthorizationHeader(ak, sk, token, req, authTime)
  398. resp, err := t.transport().RoundTrip(req)
  399. return resp, err
  400. }
  401. func (t *CVMCredentialsTransport) transport() http.RoundTripper {
  402. if t.Transport != nil {
  403. return t.Transport
  404. }
  405. return http.DefaultTransport
  406. }