You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.

320 lines
8.7 KiB

4 years ago
  1. package cos
  2. import (
  3. "crypto/hmac"
  4. "crypto/sha1"
  5. "fmt"
  6. "hash"
  7. "net/http"
  8. "net/url"
  9. "sort"
  10. "strings"
  11. "sync"
  12. "time"
  13. )
  14. const sha1SignAlgorithm = "sha1"
  15. const privateHeaderPrefix = "x-cos-"
  16. const defaultAuthExpire = time.Hour
  17. // 需要校验的 Headers 列表
  18. var needSignHeaders = map[string]bool{
  19. "host": true,
  20. "range": true,
  21. "x-cos-acl": true,
  22. "x-cos-grant-read": true,
  23. "x-cos-grant-write": true,
  24. "x-cos-grant-full-control": true,
  25. "response-content-type": true,
  26. "response-content-language": true,
  27. "response-expires": true,
  28. "response-cache-control": true,
  29. "response-content-disposition": true,
  30. "response-content-encoding": true,
  31. "cache-control": true,
  32. "content-disposition": true,
  33. "content-encoding": true,
  34. "content-type": true,
  35. "content-length": true,
  36. "content-md5": true,
  37. "expect": true,
  38. "expires": true,
  39. "x-cos-content-sha1": true,
  40. "x-cos-storage-class": true,
  41. "if-modified-since": true,
  42. "origin": true,
  43. "access-control-request-method": true,
  44. "access-control-request-headers": true,
  45. "x-cos-object-type": true,
  46. }
  47. var ciParameters = map[string]bool{
  48. "imagemogr2/": true,
  49. "watermark/": true,
  50. "imageview2/": true,
  51. }
  52. func safeURLEncode(s string) string {
  53. s = encodeURIComponent(s)
  54. s = strings.Replace(s, "!", "%21", -1)
  55. s = strings.Replace(s, "'", "%27", -1)
  56. s = strings.Replace(s, "(", "%28", -1)
  57. s = strings.Replace(s, ")", "%29", -1)
  58. s = strings.Replace(s, "*", "%2A", -1)
  59. return s
  60. }
  61. type valuesSignMap map[string][]string
  62. func (vs valuesSignMap) Add(key, value string) {
  63. key = strings.ToLower(key)
  64. vs[key] = append(vs[key], value)
  65. }
  66. func (vs valuesSignMap) Encode() string {
  67. var keys []string
  68. for k := range vs {
  69. keys = append(keys, k)
  70. }
  71. sort.Strings(keys)
  72. var pairs []string
  73. for _, k := range keys {
  74. items := vs[k]
  75. sort.Strings(items)
  76. for _, val := range items {
  77. pairs = append(
  78. pairs,
  79. fmt.Sprintf("%s=%s", safeURLEncode(k), safeURLEncode(val)))
  80. }
  81. }
  82. return strings.Join(pairs, "&")
  83. }
  84. // AuthTime 用于生成签名所需的 q-sign-time 和 q-key-time 相关参数
  85. type AuthTime struct {
  86. SignStartTime time.Time
  87. SignEndTime time.Time
  88. KeyStartTime time.Time
  89. KeyEndTime time.Time
  90. }
  91. // NewAuthTime 生成 AuthTime 的便捷函数
  92. //
  93. // expire: 从现在开始多久过期.
  94. func NewAuthTime(expire time.Duration) *AuthTime {
  95. signStartTime := time.Now()
  96. keyStartTime := signStartTime
  97. signEndTime := signStartTime.Add(expire)
  98. keyEndTime := signEndTime
  99. return &AuthTime{
  100. SignStartTime: signStartTime,
  101. SignEndTime: signEndTime,
  102. KeyStartTime: keyStartTime,
  103. KeyEndTime: keyEndTime,
  104. }
  105. }
  106. // signString return q-sign-time string
  107. func (a *AuthTime) signString() string {
  108. return fmt.Sprintf("%d;%d", a.SignStartTime.Unix(), a.SignEndTime.Unix())
  109. }
  110. // keyString return q-key-time string
  111. func (a *AuthTime) keyString() string {
  112. return fmt.Sprintf("%d;%d", a.KeyStartTime.Unix(), a.KeyEndTime.Unix())
  113. }
  114. // newAuthorization 通过一系列步骤生成最终需要的 Authorization 字符串
  115. func newAuthorization(secretID, secretKey string, req *http.Request, authTime *AuthTime) string {
  116. signTime := authTime.signString()
  117. keyTime := authTime.keyString()
  118. signKey := calSignKey(secretKey, keyTime)
  119. req.Header.Set("Host", req.Host)
  120. formatHeaders := *new(string)
  121. signedHeaderList := *new([]string)
  122. formatHeaders, signedHeaderList = genFormatHeaders(req.Header)
  123. formatParameters, signedParameterList := genFormatParameters(req.URL.Query())
  124. formatString := genFormatString(req.Method, *req.URL, formatParameters, formatHeaders)
  125. stringToSign := calStringToSign(sha1SignAlgorithm, keyTime, formatString)
  126. signature := calSignature(signKey, stringToSign)
  127. return genAuthorization(
  128. secretID, signTime, keyTime, signature, signedHeaderList,
  129. signedParameterList,
  130. )
  131. }
  132. // AddAuthorizationHeader 给 req 增加签名信息
  133. func AddAuthorizationHeader(secretID, secretKey string, sessionToken string, req *http.Request, authTime *AuthTime) {
  134. if secretID == "" {
  135. return
  136. }
  137. auth := newAuthorization(secretID, secretKey, req,
  138. authTime,
  139. )
  140. if len(sessionToken) > 0 {
  141. req.Header.Set("x-cos-security-token", sessionToken)
  142. }
  143. req.Header.Set("Authorization", auth)
  144. }
  145. // calSignKey 计算 SignKey
  146. func calSignKey(secretKey, keyTime string) string {
  147. digest := calHMACDigest(secretKey, keyTime, sha1SignAlgorithm)
  148. return fmt.Sprintf("%x", digest)
  149. }
  150. // calStringToSign 计算 StringToSign
  151. func calStringToSign(signAlgorithm, signTime, formatString string) string {
  152. h := sha1.New()
  153. h.Write([]byte(formatString))
  154. return fmt.Sprintf("%s\n%s\n%x\n", signAlgorithm, signTime, h.Sum(nil))
  155. }
  156. // calSignature 计算 Signature
  157. func calSignature(signKey, stringToSign string) string {
  158. digest := calHMACDigest(signKey, stringToSign, sha1SignAlgorithm)
  159. return fmt.Sprintf("%x", digest)
  160. }
  161. // genAuthorization 生成 Authorization
  162. func genAuthorization(secretID, signTime, keyTime, signature string, signedHeaderList, signedParameterList []string) string {
  163. return strings.Join([]string{
  164. "q-sign-algorithm=" + sha1SignAlgorithm,
  165. "q-ak=" + secretID,
  166. "q-sign-time=" + signTime,
  167. "q-key-time=" + keyTime,
  168. "q-header-list=" + strings.Join(signedHeaderList, ";"),
  169. "q-url-param-list=" + strings.Join(signedParameterList, ";"),
  170. "q-signature=" + signature,
  171. }, "&")
  172. }
  173. // genFormatString 生成 FormatString
  174. func genFormatString(method string, uri url.URL, formatParameters, formatHeaders string) string {
  175. formatMethod := strings.ToLower(method)
  176. formatURI := uri.Path
  177. return fmt.Sprintf("%s\n%s\n%s\n%s\n", formatMethod, formatURI,
  178. formatParameters, formatHeaders,
  179. )
  180. }
  181. // genFormatParameters 生成 FormatParameters 和 SignedParameterList
  182. // instead of the url.Values{}
  183. func genFormatParameters(parameters url.Values) (formatParameters string, signedParameterList []string) {
  184. ps := valuesSignMap{}
  185. for key, values := range parameters {
  186. key = strings.ToLower(key)
  187. for _, value := range values {
  188. if !isCIParameter(key) {
  189. ps.Add(key, value)
  190. signedParameterList = append(signedParameterList, key)
  191. }
  192. }
  193. }
  194. //formatParameters = strings.ToLower(ps.Encode())
  195. formatParameters = ps.Encode()
  196. sort.Strings(signedParameterList)
  197. return
  198. }
  199. // genFormatHeaders 生成 FormatHeaders 和 SignedHeaderList
  200. func genFormatHeaders(headers http.Header) (formatHeaders string, signedHeaderList []string) {
  201. hs := valuesSignMap{}
  202. for key, values := range headers {
  203. key = strings.ToLower(key)
  204. for _, value := range values {
  205. if isSignHeader(key) {
  206. hs.Add(key, value)
  207. signedHeaderList = append(signedHeaderList, key)
  208. }
  209. }
  210. }
  211. formatHeaders = hs.Encode()
  212. sort.Strings(signedHeaderList)
  213. return
  214. }
  215. // HMAC 签名
  216. func calHMACDigest(key, msg, signMethod string) []byte {
  217. var hashFunc func() hash.Hash
  218. switch signMethod {
  219. case "sha1":
  220. hashFunc = sha1.New
  221. default:
  222. hashFunc = sha1.New
  223. }
  224. h := hmac.New(hashFunc, []byte(key))
  225. h.Write([]byte(msg))
  226. return h.Sum(nil)
  227. }
  228. func isCIParameter(key string) bool {
  229. for k, v := range ciParameters {
  230. if strings.HasPrefix(key, k) && v {
  231. return true
  232. }
  233. }
  234. return false
  235. }
  236. func isSignHeader(key string) bool {
  237. for k, v := range needSignHeaders {
  238. if key == k && v {
  239. return true
  240. }
  241. }
  242. return strings.HasPrefix(key, privateHeaderPrefix)
  243. }
  244. // AuthorizationTransport 给请求增加 Authorization header
  245. type AuthorizationTransport struct {
  246. SecretID string
  247. SecretKey string
  248. SessionToken string
  249. rwLocker sync.RWMutex
  250. // 签名多久过期
  251. Expire time.Duration
  252. Transport http.RoundTripper
  253. }
  254. // SetCredential update the SecretID(ak), SercretKey(sk), sessiontoken
  255. func (t *AuthorizationTransport) SetCredential(ak, sk, token string) {
  256. t.rwLocker.Lock()
  257. defer t.rwLocker.Unlock()
  258. t.SecretID = ak
  259. t.SecretKey = sk
  260. t.SessionToken = token
  261. }
  262. // GetCredential get the ak, sk, token
  263. func (t *AuthorizationTransport) GetCredential() (string, string, string) {
  264. t.rwLocker.RLock()
  265. defer t.rwLocker.RUnlock()
  266. return t.SecretID, t.SecretKey, t.SessionToken
  267. }
  268. // RoundTrip implements the RoundTripper interface.
  269. func (t *AuthorizationTransport) RoundTrip(req *http.Request) (*http.Response, error) {
  270. req = cloneRequest(req) // per RoundTrip contract
  271. ak, sk, token := t.GetCredential()
  272. // 增加 Authorization header
  273. authTime := NewAuthTime(defaultAuthExpire)
  274. AddAuthorizationHeader(ak, sk, token, req, authTime)
  275. resp, err := t.transport().RoundTrip(req)
  276. return resp, err
  277. }
  278. func (t *AuthorizationTransport) transport() http.RoundTripper {
  279. if t.Transport != nil {
  280. return t.Transport
  281. }
  282. return http.DefaultTransport
  283. }