You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

263 lines
7.6 KiB

  1. package cos
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha1"
  5. "fmt"
  6. "hash"
  7. "net/http"
  8. "net/url"
  9. "sort"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. const sha1SignAlgorithm = "sha1"
  15. const privateHeaderPrefix = "x-cos-"
  16. const defaultAuthExpire = time.Hour
  17. // 需要校验的 Headers 列表
  18. var needSignHeaders = map[string]bool{
  19. "host": true,
  20. "range": true,
  21. "x-cos-acl": true,
  22. "x-cos-grant-read": true,
  23. "x-cos-grant-write": true,
  24. "x-cos-grant-full-control": true,
  25. "response-content-type": true,
  26. "response-content-language": true,
  27. "response-expires": true,
  28. "response-cache-control": true,
  29. "response-content-disposition": true,
  30. "response-content-encoding": true,
  31. "cache-control": true,
  32. "content-disposition": true,
  33. "content-encoding": true,
  34. "content-type": true,
  35. "content-length": true,
  36. "content-md5": true,
  37. "expect": true,
  38. "expires": true,
  39. "x-cos-content-sha1": true,
  40. "x-cos-storage-class": true,
  41. "if-modified-since": true,
  42. "origin": true,
  43. "access-control-request-method": true,
  44. "access-control-request-headers": true,
  45. "x-cos-object-type": true,
  46. }
  47. // AuthTime 用于生成签名所需的 q-sign-time 和 q-key-time 相关参数
  48. type AuthTime struct {
  49. SignStartTime time.Time
  50. SignEndTime time.Time
  51. KeyStartTime time.Time
  52. KeyEndTime time.Time
  53. }
  54. // NewAuthTime 生成 AuthTime 的便捷函数
  55. //
  56. // expire: 从现在开始多久过期.
  57. func NewAuthTime(expire time.Duration) *AuthTime {
  58. signStartTime := time.Now()
  59. keyStartTime := signStartTime
  60. signEndTime := signStartTime.Add(expire)
  61. keyEndTime := signEndTime
  62. return &AuthTime{
  63. SignStartTime: signStartTime,
  64. SignEndTime: signEndTime,
  65. KeyStartTime: keyStartTime,
  66. KeyEndTime: keyEndTime,
  67. }
  68. }
  69. // signString return q-sign-time string
  70. func (a *AuthTime) signString() string {
  71. return fmt.Sprintf("%d;%d", a.SignStartTime.Unix(), a.SignEndTime.Unix())
  72. }
  73. // keyString return q-key-time string
  74. func (a *AuthTime) keyString() string {
  75. return fmt.Sprintf("%d;%d", a.KeyStartTime.Unix(), a.KeyEndTime.Unix())
  76. }
  77. // newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
  78. func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
  79. signTime := authTime.signString()
  80. keyTime := authTime.keyString()
  81. signKey := calSignKey(secretKey, keyTime)
  82. formatHeaders := *new(string)
  83. signedHeaderList := *new([]string)
  84. formatHeaders, signedHeaderList = genFormatHeaders(req.Header)
  85. formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
  86. formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
  87. stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
  88. signature := calSignature(signKey, stringToSign)
  89. return genAuthorization(
  90. secretID, signTime, keyTime, signature, signedHeaderList,
  91. signedParameterList,
  92. )
  93. }
  94. // AddAuthorizationHeader 给 req 增加签名信息
  95. func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
  96. auth := newAuthorization(secretID, secretKey, req,
  97. authTime,
  98. )
  99. if len(sessionToken) > 0 {
  100. req.Header.Set("x-cos-security-token", sessionToken)
  101. }
  102. req.Header.Set("Authorization", auth)
  103. }
  104. // calSignKey 计算 SignKey
  105. func calSignKey(secretKey, keyTime string) string {
  106. digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm)
  107. return fmt.Sprintf("%x", digest)
  108. }
  109. // calStringToSign 计算 StringToSign
  110. func calStringToSign(signAlgorithm, signTime, formatString string) string {
  111. h := sha1.New()
  112. h.Write([]byte(formatString))
  113. return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, h.Sum(nil))
  114. }
  115. // calSignature 计算 Signature
  116. func calSignature(signKey, stringToSign string) string {
  117. digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
  118. return fmt.Sprintf("%x", digest)
  119. }
  120. // genAuthorization 生成 Authorization
  121. func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
  122. return strings.Join([]string{
  123. "q-sign-algorithm=" + sha1SignAlgorithm,
  124. "q-ak=" + secretID,
  125. "q-sign-time=" + signTime,
  126. "q-key-time=" + keyTime,
  127. "q-header-list=" + strings.Join(signedHeaderList, ";"),
  128. "q-url-param-list=" + strings.Join(signedParameterList, ";"),
  129. "q-signature=" + signature,
  130. }, "&")
  131. }
  132. // genFormatString 生成 FormatString
  133. func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
  134. formatMethod := strings.ToLower(method)
  135. formatURI := uri.Path
  136. return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI,
  137. formatParameters, formatHeaders,
  138. )
  139. }
  140. // genFormatParameters 生成 FormatParameters 和 SignedParameterList
  141. func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
  142. ps := url.Values{}
  143. for key, values := range parameters {
  144. for _, value := range values {
  145. key = strings.ToLower(key)
  146. ps.Add(key, value)
  147. signedParameterList = append(signedParameterList, key)
  148. }
  149. }
  150. //formatParameters = strings.ToLower(ps.Encode())
  151. formatParameters = ps.Encode()
  152. sort.Strings(signedParameterList)
  153. return
  154. }
  155. // genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
  156. func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
  157. hs := url.Values{}
  158. for key, values := range headers {
  159. for _, value := range values {
  160. key = strings.ToLower(key)
  161. if isSignHeader(key) {
  162. hs.Add(key, value)
  163. signedHeaderList = append(signedHeaderList, key)
  164. }
  165. }
  166. }
  167. formatHeaders = hs.Encode()
  168. sort.Strings(signedHeaderList)
  169. return
  170. }
  171. // HMAC 签名
  172. func calHMACDigest(key, msg, signMethod string) []byte {
  173. var hashFunc func() hash.Hash
  174. switch signMethod {
  175. case "sha1":
  176. hashFunc = sha1.New
  177. default:
  178. hashFunc = sha1.New
  179. }
  180. h := hmac.New(hashFunc, []byte(key))
  181. h.Write([]byte(msg))
  182. return h.Sum(nil)
  183. }
  184. func isSignHeader(key string) bool {
  185. for k, v := range needSignHeaders {
  186. if key == k && v {
  187. return true
  188. }
  189. }
  190. return strings.HasPrefix(key, privateHeaderPrefix)
  191. }
  192. // AuthorizationTransport 给请求增加 Authorization header
  193. type AuthorizationTransport struct {
  194. SecretID string
  195. SecretKey string
  196. SessionToken string
  197. rwLocker sync.RWMutex
  198. // 签名多久过期
  199. Expire time.Duration
  200. Transport http.RoundTripper
  201. }
  202. // SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken
  203. func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
  204. t.rwLocker.Lock()
  205. defer t.rwLocker.Unlock()
  206. t.SecretID = ak
  207. t.SecretKey = sk
  208. t.SessionToken = token
  209. }
  210. // GetCredential get the ak, sk, token
  211. func (t *AuthorizationTransport) GetCredential() (string, string, string) {
  212. t.rwLocker.RLock()
  213. defer t.rwLocker.RUnlock()
  214. return t.SecretID, t.SecretKey, t.SessionToken
  215. }
  216. // RoundTrip implements the RoundTripper interface.
  217. func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  218. req = cloneRequest(req) // per RoundTrip contract
  219. if t.Expire == time.Duration(0) {
  220. t.Expire = defaultAuthExpire
  221. }
  222. ak, sk, token := t.GetCredential()
  223. // 增加 Authorization header
  224. authTime := NewAuthTime(t.Expire)
  225. AddAuthorizationHeader(ak, sk, token, req, authTime)
  226. resp, err := t.transport().RoundTrip(req)
  227. return resp, err
  228. }
  229. func (t *AuthorizationTransport) transport() http.RoundTripper {
  230. if t.Transport != nil {
  231. return t.Transport
  232. }
  233. return http.DefaultTransport
  234. }