auth_test.go 3.3 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143
  1. // Copyright The OpenTelemetry Authors
  2. // SPDX-License-Identifier: Apache-2.0
  3. package f5cloudexporter
  4. import (
  5. "fmt"
  6. "net/http"
  7. "net/http/httptest"
  8. "strings"
  9. "testing"
  10. "time"
  11. "github.com/stretchr/testify/assert"
  12. "golang.org/x/oauth2"
  13. )
  14. type ErrorRoundTripper struct{}
  15. func (ert *ErrorRoundTripper) RoundTrip(_ *http.Request) (*http.Response, error) {
  16. return nil, fmt.Errorf("error")
  17. }
  18. func TestF5CloudAuthRoundTripper_RoundTrip(t *testing.T) {
  19. validTokenSource := createMockTokenSource()
  20. source := "tests"
  21. defaultRoundTripper := (http.RoundTripper)(http.DefaultTransport.(*http.Transport).Clone())
  22. errorRoundTripper := &ErrorRoundTripper{}
  23. tests := []struct {
  24. name string
  25. rt http.RoundTripper
  26. token oauth2.TokenSource
  27. shouldError bool
  28. }{
  29. {
  30. name: "Test valid token source",
  31. rt: defaultRoundTripper,
  32. token: validTokenSource,
  33. shouldError: false,
  34. },
  35. {
  36. name: "Test invalid token source",
  37. rt: defaultRoundTripper,
  38. token: &InvalidTokenSource{},
  39. shouldError: true,
  40. },
  41. {
  42. name: "Test error in next round tripper",
  43. rt: errorRoundTripper,
  44. token: validTokenSource,
  45. shouldError: true,
  46. },
  47. }
  48. for _, tt := range tests {
  49. t.Run(tt.name, func(t *testing.T) {
  50. server := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
  51. assert.Equal(t, "Bearer test_access_token", r.Header.Get("Authorization"))
  52. assert.Equal(t, "tests", r.Header.Get(sourceHeader))
  53. }))
  54. defer server.Close()
  55. rt, err := newF5CloudAuthRoundTripper(tt.token, source, tt.rt)
  56. assert.NoError(t, err)
  57. req, err := http.NewRequest("POST", server.URL, strings.NewReader(""))
  58. assert.NoError(t, err)
  59. res, err := rt.RoundTrip(req)
  60. if tt.shouldError {
  61. assert.Nil(t, res)
  62. assert.Error(t, err)
  63. return
  64. }
  65. assert.NoError(t, err)
  66. assert.Equal(t, res.StatusCode, 200)
  67. })
  68. }
  69. }
  70. func TestCreateF5CloudAuthRoundTripperWithToken(t *testing.T) {
  71. defaultRoundTripper := (http.RoundTripper)(http.DefaultTransport.(*http.Transport).Clone())
  72. token := createMockTokenSource()
  73. source := "test"
  74. tests := []struct {
  75. name string
  76. token oauth2.TokenSource
  77. source string
  78. rt http.RoundTripper
  79. shouldError bool
  80. }{
  81. {
  82. name: "success_case",
  83. token: token,
  84. source: source,
  85. rt: defaultRoundTripper,
  86. shouldError: false,
  87. },
  88. {
  89. name: "no_token_provided_error",
  90. token: nil,
  91. source: source,
  92. rt: defaultRoundTripper,
  93. shouldError: true,
  94. },
  95. {
  96. name: "no_source_provided_error",
  97. token: token,
  98. source: "",
  99. rt: defaultRoundTripper,
  100. shouldError: true,
  101. },
  102. }
  103. for _, tt := range tests {
  104. t.Run(tt.name, func(t *testing.T) {
  105. _, err := newF5CloudAuthRoundTripper(tt.token, tt.source, tt.rt)
  106. if tt.shouldError {
  107. assert.Error(t, err)
  108. return
  109. }
  110. assert.NoError(t, err)
  111. })
  112. }
  113. }
  114. func createMockTokenSource() oauth2.TokenSource {
  115. tkn := &oauth2.Token{
  116. AccessToken: "test_access_token",
  117. TokenType: "",
  118. RefreshToken: "",
  119. Expiry: time.Time{},
  120. }
  121. return oauth2.StaticTokenSource(tkn)
  122. }
  123. type InvalidTokenSource struct{}
  124. func (ts *InvalidTokenSource) Token() (*oauth2.Token, error) {
  125. return nil, fmt.Errorf("bad TokenSource for testing")
  126. }