You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

448 lines
13 KiB

4 years ago
  1. package cos
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha1"
  5. "encoding/json"
  6. "fmt"
  7. "hash"
  8. "io/ioutil"
  9. "net/http"
  10. "net/url"
  11. "sort"
  12. "strings"
  13. "sync"
  14. "time"
  15. )
  16. const (
  17. sha1SignAlgorithm = "sha1"
  18. privateHeaderPrefix = "x-cos-"
  19. defaultAuthExpire = time.Hour
  20. )
  21. var (
  22. defaultCVMAuthExpire = int64(600)
  23. defaultCVMSchema = "http"
  24. defaultCVMMetaHost = "metadata.tencentyun.com"
  25. defaultCVMCredURI = "latest/meta-data/cam/security-credentials"
  26. )
  27. // 需要校验的 Headers 列表
  28. var needSignHeaders = map[string]bool{
  29. "host": true,
  30. "range": true,
  31. "x-cos-acl": true,
  32. "x-cos-grant-read": true,
  33. "x-cos-grant-write": true,
  34. "x-cos-grant-full-control": true,
  35. "response-content-type": true,
  36. "response-content-language": true,
  37. "response-expires": true,
  38. "response-cache-control": true,
  39. "response-content-disposition": true,
  40. "response-content-encoding": true,
  41. "cache-control": true,
  42. "content-disposition": true,
  43. "content-encoding": true,
  44. "content-type": true,
  45. "content-length": true,
  46. "content-md5": true,
  47. "expect": true,
  48. "expires": true,
  49. "x-cos-content-sha1": true,
  50. "x-cos-storage-class": true,
  51. "if-modified-since": true,
  52. "origin": true,
  53. "access-control-request-method": true,
  54. "access-control-request-headers": true,
  55. "x-cos-object-type": true,
  56. }
  57. var ciParameters = map[string]bool{
  58. "imagemogr2/": true,
  59. "watermark/": true,
  60. "imageview2/": true,
  61. }
  62. func safeURLEncode(s string) string {
  63. s = encodeURIComponent(s)
  64. s = strings.Replace(s, "!", "%21", -1)
  65. s = strings.Replace(s, "'", "%27", -1)
  66. s = strings.Replace(s, "(", "%28", -1)
  67. s = strings.Replace(s, ")", "%29", -1)
  68. s = strings.Replace(s, "*", "%2A", -1)
  69. return s
  70. }
  71. type valuesSignMap map[string][]string
  72. func (vs valuesSignMap) Add(key, value string) {
  73. key = strings.ToLower(key)
  74. vs[key] = append(vs[key], value)
  75. }
  76. func (vs valuesSignMap) Encode() string {
  77. var keys []string
  78. for k := range vs {
  79. keys = append(keys, k)
  80. }
  81. sort.Strings(keys)
  82. var pairs []string
  83. for _, k := range keys {
  84. items := vs[k]
  85. sort.Strings(items)
  86. for _, val := range items {
  87. pairs = append(
  88. pairs,
  89. fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val)))
  90. }
  91. }
  92. return strings.Join(pairs, "&")
  93. }
  94. // AuthTime 用于生成签名所需的 q-sign-time 和 q-key-time 相关参数
  95. type AuthTime struct {
  96. SignStartTime time.Time
  97. SignEndTime time.Time
  98. KeyStartTime time.Time
  99. KeyEndTime time.Time
  100. }
  101. // NewAuthTime 生成 AuthTime 的便捷函数
  102. //
  103. // expire: 从现在开始多久过期.
  104. func NewAuthTime(expire time.Duration) *AuthTime {
  105. signStartTime := time.Now()
  106. keyStartTime := signStartTime
  107. signEndTime := signStartTime.Add(expire)
  108. keyEndTime := signEndTime
  109. return &AuthTime{
  110. SignStartTime: signStartTime,
  111. SignEndTime: signEndTime,
  112. KeyStartTime: keyStartTime,
  113. KeyEndTime: keyEndTime,
  114. }
  115. }
  116. // signString return q-sign-time string
  117. func (a *AuthTime) signString() string {
  118. return fmt.Sprintf("%d;%d", a.SignStartTime.Unix(), a.SignEndTime.Unix())
  119. }
  120. // keyString return q-key-time string
  121. func (a *AuthTime) keyString() string {
  122. return fmt.Sprintf("%d;%d", a.KeyStartTime.Unix(), a.KeyEndTime.Unix())
  123. }
  124. // newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
  125. func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
  126. signTime := authTime.signString()
  127. keyTime := authTime.keyString()
  128. signKey := calSignKey(secretKey, keyTime)
  129. req.Header.Set("Host", req.Host)
  130. formatHeaders := *new(string)
  131. signedHeaderList := *new([]string)
  132. formatHeaders, signedHeaderList = genFormatHeaders(req.Header)
  133. formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
  134. formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
  135. stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
  136. signature := calSignature(signKey, stringToSign)
  137. return genAuthorization(
  138. secretID, signTime, keyTime, signature, signedHeaderList,
  139. signedParameterList,
  140. )
  141. }
  142. // AddAuthorizationHeader 给 req 增加签名信息
  143. func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
  144. if secretID == "" {
  145. return
  146. }
  147. auth := newAuthorization(secretID, secretKey, req,
  148. authTime,
  149. )
  150. if len(sessionToken) > 0 {
  151. req.Header.Set("x-cos-security-token", sessionToken)
  152. }
  153. req.Header.Set("Authorization", auth)
  154. }
  155. // calSignKey 计算 SignKey
  156. func calSignKey(secretKey, keyTime string) string {
  157. digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm)
  158. return fmt.Sprintf("%x", digest)
  159. }
  160. // calStringToSign 计算 StringToSign
  161. func calStringToSign(signAlgorithm, signTime, formatString string) string {
  162. h := sha1.New()
  163. h.Write([]byte(formatString))
  164. return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, h.Sum(nil))
  165. }
  166. // calSignature 计算 Signature
  167. func calSignature(signKey, stringToSign string) string {
  168. digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
  169. return fmt.Sprintf("%x", digest)
  170. }
  171. // genAuthorization 生成 Authorization
  172. func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
  173. return strings.Join([]string{
  174. "q-sign-algorithm=" + sha1SignAlgorithm,
  175. "q-ak=" + secretID,
  176. "q-sign-time=" + signTime,
  177. "q-key-time=" + keyTime,
  178. "q-header-list=" + strings.Join(signedHeaderList, ";"),
  179. "q-url-param-list=" + strings.Join(signedParameterList, ";"),
  180. "q-signature=" + signature,
  181. }, "&")
  182. }
  183. // genFormatString 生成 FormatString
  184. func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
  185. formatMethod := strings.ToLower(method)
  186. formatURI := uri.Path
  187. return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI,
  188. formatParameters, formatHeaders,
  189. )
  190. }
  191. // genFormatParameters 生成 FormatParameters 和 SignedParameterList
  192. // instead of the url.Values{}
  193. func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
  194. ps := valuesSignMap{}
  195. for key, values := range parameters {
  196. key = strings.ToLower(key)
  197. for _, value := range values {
  198. if !isCIParameter(key) {
  199. ps.Add(key, value)
  200. signedParameterList = append(signedParameterList, key)
  201. }
  202. }
  203. }
  204. //formatParameters = strings.ToLower(ps.Encode())
  205. formatParameters = ps.Encode()
  206. sort.Strings(signedParameterList)
  207. return
  208. }
  209. // genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
  210. func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
  211. hs := valuesSignMap{}
  212. for key, values := range headers {
  213. key = strings.ToLower(key)
  214. for _, value := range values {
  215. if isSignHeader(key) {
  216. hs.Add(key, value)
  217. signedHeaderList = append(signedHeaderList, key)
  218. }
  219. }
  220. }
  221. formatHeaders = hs.Encode()
  222. sort.Strings(signedHeaderList)
  223. return
  224. }
  225. // HMAC 签名
  226. func calHMACDigest(key, msg, signMethod string) []byte {
  227. var hashFunc func() hash.Hash
  228. switch signMethod {
  229. case "sha1":
  230. hashFunc = sha1.New
  231. default:
  232. hashFunc = sha1.New
  233. }
  234. h := hmac.New(hashFunc, []byte(key))
  235. h.Write([]byte(msg))
  236. return h.Sum(nil)
  237. }
  238. func isCIParameter(key string) bool {
  239. for k, v := range ciParameters {
  240. if strings.HasPrefix(key, k) && v {
  241. return true
  242. }
  243. }
  244. return false
  245. }
  246. func isSignHeader(key string) bool {
  247. for k, v := range needSignHeaders {
  248. if key == k && v {
  249. return true
  250. }
  251. }
  252. return strings.HasPrefix(key, privateHeaderPrefix)
  253. }
  254. // AuthorizationTransport 给请求增加 Authorization header
  255. type AuthorizationTransport struct {
  256. SecretID string
  257. SecretKey string
  258. SessionToken string
  259. rwLocker sync.RWMutex
  260. // 签名多久过期
  261. Expire time.Duration
  262. Transport http.RoundTripper
  263. }
  264. // SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken
  265. func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
  266. t.rwLocker.Lock()
  267. defer t.rwLocker.Unlock()
  268. t.SecretID = ak
  269. t.SecretKey = sk
  270. t.SessionToken = token
  271. }
  272. // GetCredential get the ak, sk, token
  273. func (t *AuthorizationTransport) GetCredential() (string, string, string) {
  274. t.rwLocker.RLock()
  275. defer t.rwLocker.RUnlock()
  276. return t.SecretID, t.SecretKey, t.SessionToken
  277. }
  278. // RoundTrip implements the RoundTripper interface.
  279. func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  280. req = cloneRequest(req) // per RoundTrip contract
  281. ak, sk, token := t.GetCredential()
  282. // 增加 Authorization header
  283. authTime := NewAuthTime(defaultAuthExpire)
  284. AddAuthorizationHeader(ak, sk, token, req, authTime)
  285. resp, err := t.transport().RoundTrip(req)
  286. return resp, err
  287. }
  288. func (t *AuthorizationTransport) transport() http.RoundTripper {
  289. if t.Transport != nil {
  290. return t.Transport
  291. }
  292. return http.DefaultTransport
  293. }
  294. type CVMSecurityCredentials struct {
  295. TmpSecretId string `json:",omitempty"`
  296. TmpSecretKey string `json:",omitempty"`
  297. ExpiredTime int64 `json:",omitempty"`
  298. Expiration string `json:",omitempty"`
  299. Token string `json:",omitempty"`
  300. Code string `json:",omitempty"`
  301. }
  302. type CVMCredentialsTransport struct {
  303. RoleName string
  304. Transport http.RoundTripper
  305. secretID string
  306. secretKey string
  307. sessionToken string
  308. expiredTime int64
  309. rwLocker sync.RWMutex
  310. }
  311. func (t *CVMCredentialsTransport) GetRoles() ([]string, error) {
  312. urlname := fmt.Sprintf("%s://%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI)
  313. resp, err := http.Get(urlname)
  314. if err != nil {
  315. return nil, err
  316. }
  317. defer resp.Body.Close()
  318. if resp.StatusCode < 200 || resp.StatusCode > 299 {
  319. bs, _ := ioutil.ReadAll(resp.Body)
  320. return nil, fmt.Errorf("get cvm security-credentials role failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
  321. }
  322. bs, err := ioutil.ReadAll(resp.Body)
  323. if err != nil {
  324. return nil, err
  325. }
  326. roles := strings.Split(strings.TrimSpace(string(bs)), "\n")
  327. if len(roles) == 0 {
  328. return nil, fmt.Errorf("get cvm security-credentials role failed, No valid cam role was found")
  329. }
  330. return roles, nil
  331. }
  332. // https://cloud.tencent.com/document/product/213/4934
  333. func (t *CVMCredentialsTransport) UpdateCredential(now int64) (string, string, string, error) {
  334. t.rwLocker.Lock()
  335. defer t.rwLocker.Unlock()
  336. if t.expiredTime > now+defaultCVMAuthExpire {
  337. return t.secretID, t.secretKey, t.sessionToken, nil
  338. }
  339. roleName := t.RoleName
  340. if roleName == "" {
  341. roles, err := t.GetRoles()
  342. if err != nil {
  343. return t.secretID, t.secretKey, t.sessionToken, err
  344. }
  345. roleName = roles[0]
  346. }
  347. urlname := fmt.Sprintf("%s://%s/%s/%s", defaultCVMSchema, defaultCVMMetaHost, defaultCVMCredURI, roleName)
  348. resp, err := http.Get(urlname)
  349. if err != nil {
  350. return t.secretID, t.secretKey, t.sessionToken, err
  351. }
  352. defer resp.Body.Close()
  353. if resp.StatusCode < 200 || resp.StatusCode > 299 {
  354. bs, _ := ioutil.ReadAll(resp.Body)
  355. return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, StatusCode: %v, Body: %v", resp.StatusCode, string(bs))
  356. }
  357. var cred CVMSecurityCredentials
  358. err = json.NewDecoder(resp.Body).Decode(&cred)
  359. if err != nil {
  360. return t.secretID, t.secretKey, t.sessionToken, err
  361. }
  362. if cred.Code != "Success" {
  363. return t.secretID, t.secretKey, t.sessionToken, fmt.Errorf("call cvm security-credentials failed, Code:%v", cred.Code)
  364. }
  365. t.secretID, t.secretKey, t.sessionToken, t.expiredTime = cred.TmpSecretId, cred.TmpSecretKey, cred.Token, cred.ExpiredTime
  366. return t.secretID, t.secretKey, t.sessionToken, nil
  367. }
  368. func (t *CVMCredentialsTransport) GetCredential() (string, string, string, error) {
  369. now := time.Now().Unix()
  370. t.rwLocker.RLock()
  371. // 提前 defaultCVMAuthExpire 获取重新获取临时密钥
  372. if t.expiredTime <= now+defaultCVMAuthExpire {
  373. expiredTime := t.expiredTime
  374. t.rwLocker.RUnlock()
  375. secretID, secretKey, secretToken, err := t.UpdateCredential(now)
  376. // 获取临时密钥失败但密钥未过期
  377. if err != nil && now < expiredTime {
  378. err = nil
  379. }
  380. return secretID, secretKey, secretToken, err
  381. }
  382. defer t.rwLocker.RUnlock()
  383. return t.secretID, t.secretKey, t.sessionToken, nil
  384. }
  385. func (t *CVMCredentialsTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  386. ak, sk, token, err := t.GetCredential()
  387. if err != nil {
  388. return nil, err
  389. }
  390. req = cloneRequest(req)
  391. // 增加 Authorization header
  392. authTime := NewAuthTime(defaultAuthExpire)
  393. AddAuthorizationHeader(ak, sk, token, req, authTime)
  394. resp, err := t.transport().RoundTrip(req)
  395. return resp, err
  396. }
  397. func (t *CVMCredentialsTransport) transport() http.RoundTripper {
  398. if t.Transport != nil {
  399. return t.Transport
  400. }
  401. return http.DefaultTransport
  402. }