add service
All checks were successful
Deploy Smart Search Backend / deploy (push) Successful in 1m47s
All checks were successful
Deploy Smart Search Backend / deploy (push) Successful in 1m47s
This commit is contained in:
@@ -49,14 +49,14 @@ func (h *AuthHandler) Login(ctx context.Context, req *pb.LoginRequest) (*pb.Logi
|
||||
}
|
||||
|
||||
func (h *AuthHandler) Refresh(ctx context.Context, req *pb.RefreshRequest) (*pb.RefreshResponse, error) {
|
||||
accessToken, err := h.authService.Refresh(ctx, req.RefreshToken)
|
||||
accessToken, refreshToken, err := h.authService.Refresh(ctx, req.RefreshToken)
|
||||
if err != nil {
|
||||
return nil, errors.ToGRPCError(err, h.logger, "AuthService.Refresh")
|
||||
}
|
||||
|
||||
return &pb.RefreshResponse{
|
||||
AccessToken: accessToken,
|
||||
RefreshToken: req.RefreshToken,
|
||||
RefreshToken: refreshToken,
|
||||
}, nil
|
||||
}
|
||||
|
||||
|
||||
@@ -32,7 +32,7 @@ type AuthServiceMock struct {
|
||||
beforeLogoutCounter uint64
|
||||
LogoutMock mAuthServiceMockLogout
|
||||
|
||||
funcRefresh func(ctx context.Context, refreshToken string) (s1 string, err error)
|
||||
funcRefresh func(ctx context.Context, refreshToken string) (newAccessToken string, newRefreshToken string, err error)
|
||||
funcRefreshOrigin string
|
||||
inspectFuncRefresh func(ctx context.Context, refreshToken string)
|
||||
afterRefreshCounter uint64
|
||||
@@ -899,8 +899,9 @@ type AuthServiceMockRefreshParamPtrs struct {
|
||||
|
||||
// AuthServiceMockRefreshResults contains results of the AuthService.Refresh
|
||||
type AuthServiceMockRefreshResults struct {
|
||||
s1 string
|
||||
err error
|
||||
newAccessToken string
|
||||
newRefreshToken string
|
||||
err error
|
||||
}
|
||||
|
||||
// AuthServiceMockRefreshOrigins contains origins of expectations of the AuthService.Refresh
|
||||
@@ -1003,7 +1004,7 @@ func (mmRefresh *mAuthServiceMockRefresh) Inspect(f func(ctx context.Context, re
|
||||
}
|
||||
|
||||
// Return sets up results that will be returned by AuthService.Refresh
|
||||
func (mmRefresh *mAuthServiceMockRefresh) Return(s1 string, err error) *AuthServiceMock {
|
||||
func (mmRefresh *mAuthServiceMockRefresh) Return(newAccessToken string, newRefreshToken string, err error) *AuthServiceMock {
|
||||
if mmRefresh.mock.funcRefresh != nil {
|
||||
mmRefresh.mock.t.Fatalf("AuthServiceMock.Refresh mock is already set by Set")
|
||||
}
|
||||
@@ -1011,13 +1012,13 @@ func (mmRefresh *mAuthServiceMockRefresh) Return(s1 string, err error) *AuthServ
|
||||
if mmRefresh.defaultExpectation == nil {
|
||||
mmRefresh.defaultExpectation = &AuthServiceMockRefreshExpectation{mock: mmRefresh.mock}
|
||||
}
|
||||
mmRefresh.defaultExpectation.results = &AuthServiceMockRefreshResults{s1, err}
|
||||
mmRefresh.defaultExpectation.results = &AuthServiceMockRefreshResults{newAccessToken, newRefreshToken, err}
|
||||
mmRefresh.defaultExpectation.returnOrigin = minimock.CallerInfo(1)
|
||||
return mmRefresh.mock
|
||||
}
|
||||
|
||||
// Set uses given function f to mock the AuthService.Refresh method
|
||||
func (mmRefresh *mAuthServiceMockRefresh) Set(f func(ctx context.Context, refreshToken string) (s1 string, err error)) *AuthServiceMock {
|
||||
func (mmRefresh *mAuthServiceMockRefresh) Set(f func(ctx context.Context, refreshToken string) (newAccessToken string, newRefreshToken string, err error)) *AuthServiceMock {
|
||||
if mmRefresh.defaultExpectation != nil {
|
||||
mmRefresh.mock.t.Fatalf("Default expectation is already set for the AuthService.Refresh method")
|
||||
}
|
||||
@@ -1048,8 +1049,8 @@ func (mmRefresh *mAuthServiceMockRefresh) When(ctx context.Context, refreshToken
|
||||
}
|
||||
|
||||
// Then sets up AuthService.Refresh return parameters for the expectation previously defined by the When method
|
||||
func (e *AuthServiceMockRefreshExpectation) Then(s1 string, err error) *AuthServiceMock {
|
||||
e.results = &AuthServiceMockRefreshResults{s1, err}
|
||||
func (e *AuthServiceMockRefreshExpectation) Then(newAccessToken string, newRefreshToken string, err error) *AuthServiceMock {
|
||||
e.results = &AuthServiceMockRefreshResults{newAccessToken, newRefreshToken, err}
|
||||
return e.mock
|
||||
}
|
||||
|
||||
@@ -1075,7 +1076,7 @@ func (mmRefresh *mAuthServiceMockRefresh) invocationsDone() bool {
|
||||
}
|
||||
|
||||
// Refresh implements mm_service.AuthService
|
||||
func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken string) (s1 string, err error) {
|
||||
func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken string) (newAccessToken string, newRefreshToken string, err error) {
|
||||
mm_atomic.AddUint64(&mmRefresh.beforeRefreshCounter, 1)
|
||||
defer mm_atomic.AddUint64(&mmRefresh.afterRefreshCounter, 1)
|
||||
|
||||
@@ -1095,7 +1096,7 @@ func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken stri
|
||||
for _, e := range mmRefresh.RefreshMock.expectations {
|
||||
if minimock.Equal(*e.params, mm_params) {
|
||||
mm_atomic.AddUint64(&e.Counter, 1)
|
||||
return e.results.s1, e.results.err
|
||||
return e.results.newAccessToken, e.results.newRefreshToken, e.results.err
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1127,7 +1128,7 @@ func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken stri
|
||||
if mm_results == nil {
|
||||
mmRefresh.t.Fatal("No results are set for the AuthServiceMock.Refresh")
|
||||
}
|
||||
return (*mm_results).s1, (*mm_results).err
|
||||
return (*mm_results).newAccessToken, (*mm_results).newRefreshToken, (*mm_results).err
|
||||
}
|
||||
if mmRefresh.funcRefresh != nil {
|
||||
return mmRefresh.funcRefresh(ctx, refreshToken)
|
||||
|
||||
@@ -67,6 +67,13 @@ type SessionRepositoryMock struct {
|
||||
afterUpdateAccessTokenCounter uint64
|
||||
beforeUpdateAccessTokenCounter uint64
|
||||
UpdateAccessTokenMock mSessionRepositoryMockUpdateAccessToken
|
||||
|
||||
funcUpdateTokens func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) (err error)
|
||||
funcUpdateTokensOrigin string
|
||||
inspectFuncUpdateTokens func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string)
|
||||
afterUpdateTokensCounter uint64
|
||||
beforeUpdateTokensCounter uint64
|
||||
UpdateTokensMock mSessionRepositoryMockUpdateTokens
|
||||
}
|
||||
|
||||
// NewSessionRepositoryMock returns a mock for mm_repository.SessionRepository
|
||||
@@ -98,6 +105,9 @@ func NewSessionRepositoryMock(t minimock.Tester) *SessionRepositoryMock {
|
||||
m.UpdateAccessTokenMock = mSessionRepositoryMockUpdateAccessToken{mock: m}
|
||||
m.UpdateAccessTokenMock.callArgs = []*SessionRepositoryMockUpdateAccessTokenParams{}
|
||||
|
||||
m.UpdateTokensMock = mSessionRepositoryMockUpdateTokens{mock: m}
|
||||
m.UpdateTokensMock.callArgs = []*SessionRepositoryMockUpdateTokensParams{}
|
||||
|
||||
t.Cleanup(m.MinimockFinish)
|
||||
|
||||
return m
|
||||
@@ -2500,6 +2510,410 @@ func (m *SessionRepositoryMock) MinimockUpdateAccessTokenInspect() {
|
||||
}
|
||||
}
|
||||
|
||||
type mSessionRepositoryMockUpdateTokens struct {
|
||||
optional bool
|
||||
mock *SessionRepositoryMock
|
||||
defaultExpectation *SessionRepositoryMockUpdateTokensExpectation
|
||||
expectations []*SessionRepositoryMockUpdateTokensExpectation
|
||||
|
||||
callArgs []*SessionRepositoryMockUpdateTokensParams
|
||||
mutex sync.RWMutex
|
||||
|
||||
expectedInvocations uint64
|
||||
expectedInvocationsOrigin string
|
||||
}
|
||||
|
||||
// SessionRepositoryMockUpdateTokensExpectation specifies expectation struct of the SessionRepository.UpdateTokens
|
||||
type SessionRepositoryMockUpdateTokensExpectation struct {
|
||||
mock *SessionRepositoryMock
|
||||
params *SessionRepositoryMockUpdateTokensParams
|
||||
paramPtrs *SessionRepositoryMockUpdateTokensParamPtrs
|
||||
expectationOrigins SessionRepositoryMockUpdateTokensExpectationOrigins
|
||||
results *SessionRepositoryMockUpdateTokensResults
|
||||
returnOrigin string
|
||||
Counter uint64
|
||||
}
|
||||
|
||||
// SessionRepositoryMockUpdateTokensParams contains parameters of the SessionRepository.UpdateTokens
|
||||
type SessionRepositoryMockUpdateTokensParams struct {
|
||||
ctx context.Context
|
||||
oldRefreshToken string
|
||||
newAccessToken string
|
||||
newRefreshToken string
|
||||
}
|
||||
|
||||
// SessionRepositoryMockUpdateTokensParamPtrs contains pointers to parameters of the SessionRepository.UpdateTokens
|
||||
type SessionRepositoryMockUpdateTokensParamPtrs struct {
|
||||
ctx *context.Context
|
||||
oldRefreshToken *string
|
||||
newAccessToken *string
|
||||
newRefreshToken *string
|
||||
}
|
||||
|
||||
// SessionRepositoryMockUpdateTokensResults contains results of the SessionRepository.UpdateTokens
|
||||
type SessionRepositoryMockUpdateTokensResults struct {
|
||||
err error
|
||||
}
|
||||
|
||||
// SessionRepositoryMockUpdateTokensOrigins contains origins of expectations of the SessionRepository.UpdateTokens
|
||||
type SessionRepositoryMockUpdateTokensExpectationOrigins struct {
|
||||
origin string
|
||||
originCtx string
|
||||
originOldRefreshToken string
|
||||
originNewAccessToken string
|
||||
originNewRefreshToken string
|
||||
}
|
||||
|
||||
// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning
|
||||
// the test will fail minimock's automatic final call check if the mocked method was not called at least once.
|
||||
// Optional() makes method check to work in '0 or more' mode.
|
||||
// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to
|
||||
// catch the problems when the expected method call is totally skipped during test run.
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Optional() *mSessionRepositoryMockUpdateTokens {
|
||||
mmUpdateTokens.optional = true
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
// Expect sets up expected params for SessionRepository.UpdateTokens
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Expect(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) *mSessionRepositoryMockUpdateTokens {
|
||||
if mmUpdateTokens.mock.funcUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation == nil {
|
||||
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.paramPtrs != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by ExpectParams functions")
|
||||
}
|
||||
|
||||
mmUpdateTokens.defaultExpectation.params = &SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken}
|
||||
mmUpdateTokens.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1)
|
||||
for _, e := range mmUpdateTokens.expectations {
|
||||
if minimock.Equal(e.params, mmUpdateTokens.defaultExpectation.params) {
|
||||
mmUpdateTokens.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmUpdateTokens.defaultExpectation.params)
|
||||
}
|
||||
}
|
||||
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
// ExpectCtxParam1 sets up expected param ctx for SessionRepository.UpdateTokens
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectCtxParam1(ctx context.Context) *mSessionRepositoryMockUpdateTokens {
|
||||
if mmUpdateTokens.mock.funcUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation == nil {
|
||||
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.params != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
|
||||
}
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs.ctx = &ctx
|
||||
mmUpdateTokens.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1)
|
||||
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
// ExpectOldRefreshTokenParam2 sets up expected param oldRefreshToken for SessionRepository.UpdateTokens
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectOldRefreshTokenParam2(oldRefreshToken string) *mSessionRepositoryMockUpdateTokens {
|
||||
if mmUpdateTokens.mock.funcUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation == nil {
|
||||
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.params != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
|
||||
}
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs.oldRefreshToken = &oldRefreshToken
|
||||
mmUpdateTokens.defaultExpectation.expectationOrigins.originOldRefreshToken = minimock.CallerInfo(1)
|
||||
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
// ExpectNewAccessTokenParam3 sets up expected param newAccessToken for SessionRepository.UpdateTokens
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectNewAccessTokenParam3(newAccessToken string) *mSessionRepositoryMockUpdateTokens {
|
||||
if mmUpdateTokens.mock.funcUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation == nil {
|
||||
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.params != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
|
||||
}
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs.newAccessToken = &newAccessToken
|
||||
mmUpdateTokens.defaultExpectation.expectationOrigins.originNewAccessToken = minimock.CallerInfo(1)
|
||||
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
// ExpectNewRefreshTokenParam4 sets up expected param newRefreshToken for SessionRepository.UpdateTokens
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectNewRefreshTokenParam4(newRefreshToken string) *mSessionRepositoryMockUpdateTokens {
|
||||
if mmUpdateTokens.mock.funcUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation == nil {
|
||||
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.params != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
|
||||
}
|
||||
mmUpdateTokens.defaultExpectation.paramPtrs.newRefreshToken = &newRefreshToken
|
||||
mmUpdateTokens.defaultExpectation.expectationOrigins.originNewRefreshToken = minimock.CallerInfo(1)
|
||||
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
// Inspect accepts an inspector function that has same arguments as the SessionRepository.UpdateTokens
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Inspect(f func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string)) *mSessionRepositoryMockUpdateTokens {
|
||||
if mmUpdateTokens.mock.inspectFuncUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("Inspect function is already set for SessionRepositoryMock.UpdateTokens")
|
||||
}
|
||||
|
||||
mmUpdateTokens.mock.inspectFuncUpdateTokens = f
|
||||
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
// Return sets up results that will be returned by SessionRepository.UpdateTokens
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Return(err error) *SessionRepositoryMock {
|
||||
if mmUpdateTokens.mock.funcUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
|
||||
}
|
||||
|
||||
if mmUpdateTokens.defaultExpectation == nil {
|
||||
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{mock: mmUpdateTokens.mock}
|
||||
}
|
||||
mmUpdateTokens.defaultExpectation.results = &SessionRepositoryMockUpdateTokensResults{err}
|
||||
mmUpdateTokens.defaultExpectation.returnOrigin = minimock.CallerInfo(1)
|
||||
return mmUpdateTokens.mock
|
||||
}
|
||||
|
||||
// Set uses given function f to mock the SessionRepository.UpdateTokens method
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Set(f func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) (err error)) *SessionRepositoryMock {
|
||||
if mmUpdateTokens.defaultExpectation != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("Default expectation is already set for the SessionRepository.UpdateTokens method")
|
||||
}
|
||||
|
||||
if len(mmUpdateTokens.expectations) > 0 {
|
||||
mmUpdateTokens.mock.t.Fatalf("Some expectations are already set for the SessionRepository.UpdateTokens method")
|
||||
}
|
||||
|
||||
mmUpdateTokens.mock.funcUpdateTokens = f
|
||||
mmUpdateTokens.mock.funcUpdateTokensOrigin = minimock.CallerInfo(1)
|
||||
return mmUpdateTokens.mock
|
||||
}
|
||||
|
||||
// When sets expectation for the SessionRepository.UpdateTokens which will trigger the result defined by the following
|
||||
// Then helper
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) When(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) *SessionRepositoryMockUpdateTokensExpectation {
|
||||
if mmUpdateTokens.mock.funcUpdateTokens != nil {
|
||||
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
|
||||
}
|
||||
|
||||
expectation := &SessionRepositoryMockUpdateTokensExpectation{
|
||||
mock: mmUpdateTokens.mock,
|
||||
params: &SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken},
|
||||
expectationOrigins: SessionRepositoryMockUpdateTokensExpectationOrigins{origin: minimock.CallerInfo(1)},
|
||||
}
|
||||
mmUpdateTokens.expectations = append(mmUpdateTokens.expectations, expectation)
|
||||
return expectation
|
||||
}
|
||||
|
||||
// Then sets up SessionRepository.UpdateTokens return parameters for the expectation previously defined by the When method
|
||||
func (e *SessionRepositoryMockUpdateTokensExpectation) Then(err error) *SessionRepositoryMock {
|
||||
e.results = &SessionRepositoryMockUpdateTokensResults{err}
|
||||
return e.mock
|
||||
}
|
||||
|
||||
// Times sets number of times SessionRepository.UpdateTokens should be invoked
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Times(n uint64) *mSessionRepositoryMockUpdateTokens {
|
||||
if n == 0 {
|
||||
mmUpdateTokens.mock.t.Fatalf("Times of SessionRepositoryMock.UpdateTokens mock can not be zero")
|
||||
}
|
||||
mm_atomic.StoreUint64(&mmUpdateTokens.expectedInvocations, n)
|
||||
mmUpdateTokens.expectedInvocationsOrigin = minimock.CallerInfo(1)
|
||||
return mmUpdateTokens
|
||||
}
|
||||
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) invocationsDone() bool {
|
||||
if len(mmUpdateTokens.expectations) == 0 && mmUpdateTokens.defaultExpectation == nil && mmUpdateTokens.mock.funcUpdateTokens == nil {
|
||||
return true
|
||||
}
|
||||
|
||||
totalInvocations := mm_atomic.LoadUint64(&mmUpdateTokens.mock.afterUpdateTokensCounter)
|
||||
expectedInvocations := mm_atomic.LoadUint64(&mmUpdateTokens.expectedInvocations)
|
||||
|
||||
return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations)
|
||||
}
|
||||
|
||||
// UpdateTokens implements mm_repository.SessionRepository
|
||||
func (mmUpdateTokens *SessionRepositoryMock) UpdateTokens(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) (err error) {
|
||||
mm_atomic.AddUint64(&mmUpdateTokens.beforeUpdateTokensCounter, 1)
|
||||
defer mm_atomic.AddUint64(&mmUpdateTokens.afterUpdateTokensCounter, 1)
|
||||
|
||||
mmUpdateTokens.t.Helper()
|
||||
|
||||
if mmUpdateTokens.inspectFuncUpdateTokens != nil {
|
||||
mmUpdateTokens.inspectFuncUpdateTokens(ctx, oldRefreshToken, newAccessToken, newRefreshToken)
|
||||
}
|
||||
|
||||
mm_params := SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken}
|
||||
|
||||
// Record call args
|
||||
mmUpdateTokens.UpdateTokensMock.mutex.Lock()
|
||||
mmUpdateTokens.UpdateTokensMock.callArgs = append(mmUpdateTokens.UpdateTokensMock.callArgs, &mm_params)
|
||||
mmUpdateTokens.UpdateTokensMock.mutex.Unlock()
|
||||
|
||||
for _, e := range mmUpdateTokens.UpdateTokensMock.expectations {
|
||||
if minimock.Equal(*e.params, mm_params) {
|
||||
mm_atomic.AddUint64(&e.Counter, 1)
|
||||
return e.results.err
|
||||
}
|
||||
}
|
||||
|
||||
if mmUpdateTokens.UpdateTokensMock.defaultExpectation != nil {
|
||||
mm_atomic.AddUint64(&mmUpdateTokens.UpdateTokensMock.defaultExpectation.Counter, 1)
|
||||
mm_want := mmUpdateTokens.UpdateTokensMock.defaultExpectation.params
|
||||
mm_want_ptrs := mmUpdateTokens.UpdateTokensMock.defaultExpectation.paramPtrs
|
||||
|
||||
mm_got := SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken}
|
||||
|
||||
if mm_want_ptrs != nil {
|
||||
|
||||
if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) {
|
||||
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
|
||||
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx))
|
||||
}
|
||||
|
||||
if mm_want_ptrs.oldRefreshToken != nil && !minimock.Equal(*mm_want_ptrs.oldRefreshToken, mm_got.oldRefreshToken) {
|
||||
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter oldRefreshToken, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
|
||||
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originOldRefreshToken, *mm_want_ptrs.oldRefreshToken, mm_got.oldRefreshToken, minimock.Diff(*mm_want_ptrs.oldRefreshToken, mm_got.oldRefreshToken))
|
||||
}
|
||||
|
||||
if mm_want_ptrs.newAccessToken != nil && !minimock.Equal(*mm_want_ptrs.newAccessToken, mm_got.newAccessToken) {
|
||||
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter newAccessToken, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
|
||||
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originNewAccessToken, *mm_want_ptrs.newAccessToken, mm_got.newAccessToken, minimock.Diff(*mm_want_ptrs.newAccessToken, mm_got.newAccessToken))
|
||||
}
|
||||
|
||||
if mm_want_ptrs.newRefreshToken != nil && !minimock.Equal(*mm_want_ptrs.newRefreshToken, mm_got.newRefreshToken) {
|
||||
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter newRefreshToken, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
|
||||
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originNewRefreshToken, *mm_want_ptrs.newRefreshToken, mm_got.newRefreshToken, minimock.Diff(*mm_want_ptrs.newRefreshToken, mm_got.newRefreshToken))
|
||||
}
|
||||
|
||||
} else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) {
|
||||
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
|
||||
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got))
|
||||
}
|
||||
|
||||
mm_results := mmUpdateTokens.UpdateTokensMock.defaultExpectation.results
|
||||
if mm_results == nil {
|
||||
mmUpdateTokens.t.Fatal("No results are set for the SessionRepositoryMock.UpdateTokens")
|
||||
}
|
||||
return (*mm_results).err
|
||||
}
|
||||
if mmUpdateTokens.funcUpdateTokens != nil {
|
||||
return mmUpdateTokens.funcUpdateTokens(ctx, oldRefreshToken, newAccessToken, newRefreshToken)
|
||||
}
|
||||
mmUpdateTokens.t.Fatalf("Unexpected call to SessionRepositoryMock.UpdateTokens. %v %v %v %v", ctx, oldRefreshToken, newAccessToken, newRefreshToken)
|
||||
return
|
||||
}
|
||||
|
||||
// UpdateTokensAfterCounter returns a count of finished SessionRepositoryMock.UpdateTokens invocations
|
||||
func (mmUpdateTokens *SessionRepositoryMock) UpdateTokensAfterCounter() uint64 {
|
||||
return mm_atomic.LoadUint64(&mmUpdateTokens.afterUpdateTokensCounter)
|
||||
}
|
||||
|
||||
// UpdateTokensBeforeCounter returns a count of SessionRepositoryMock.UpdateTokens invocations
|
||||
func (mmUpdateTokens *SessionRepositoryMock) UpdateTokensBeforeCounter() uint64 {
|
||||
return mm_atomic.LoadUint64(&mmUpdateTokens.beforeUpdateTokensCounter)
|
||||
}
|
||||
|
||||
// Calls returns a list of arguments used in each call to SessionRepositoryMock.UpdateTokens.
|
||||
// The list is in the same order as the calls were made (i.e. recent calls have a higher index)
|
||||
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Calls() []*SessionRepositoryMockUpdateTokensParams {
|
||||
mmUpdateTokens.mutex.RLock()
|
||||
|
||||
argCopy := make([]*SessionRepositoryMockUpdateTokensParams, len(mmUpdateTokens.callArgs))
|
||||
copy(argCopy, mmUpdateTokens.callArgs)
|
||||
|
||||
mmUpdateTokens.mutex.RUnlock()
|
||||
|
||||
return argCopy
|
||||
}
|
||||
|
||||
// MinimockUpdateTokensDone returns true if the count of the UpdateTokens invocations corresponds
|
||||
// the number of defined expectations
|
||||
func (m *SessionRepositoryMock) MinimockUpdateTokensDone() bool {
|
||||
if m.UpdateTokensMock.optional {
|
||||
// Optional methods provide '0 or more' call count restriction.
|
||||
return true
|
||||
}
|
||||
|
||||
for _, e := range m.UpdateTokensMock.expectations {
|
||||
if mm_atomic.LoadUint64(&e.Counter) < 1 {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return m.UpdateTokensMock.invocationsDone()
|
||||
}
|
||||
|
||||
// MinimockUpdateTokensInspect logs each unmet expectation
|
||||
func (m *SessionRepositoryMock) MinimockUpdateTokensInspect() {
|
||||
for _, e := range m.UpdateTokensMock.expectations {
|
||||
if mm_atomic.LoadUint64(&e.Counter) < 1 {
|
||||
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s with params: %#v", e.expectationOrigins.origin, *e.params)
|
||||
}
|
||||
}
|
||||
|
||||
afterUpdateTokensCounter := mm_atomic.LoadUint64(&m.afterUpdateTokensCounter)
|
||||
// if default expectation was set then invocations count should be greater than zero
|
||||
if m.UpdateTokensMock.defaultExpectation != nil && afterUpdateTokensCounter < 1 {
|
||||
if m.UpdateTokensMock.defaultExpectation.params == nil {
|
||||
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s", m.UpdateTokensMock.defaultExpectation.returnOrigin)
|
||||
} else {
|
||||
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s with params: %#v", m.UpdateTokensMock.defaultExpectation.expectationOrigins.origin, *m.UpdateTokensMock.defaultExpectation.params)
|
||||
}
|
||||
}
|
||||
// if func was set then invocations count should be greater than zero
|
||||
if m.funcUpdateTokens != nil && afterUpdateTokensCounter < 1 {
|
||||
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s", m.funcUpdateTokensOrigin)
|
||||
}
|
||||
|
||||
if !m.UpdateTokensMock.invocationsDone() && afterUpdateTokensCounter > 0 {
|
||||
m.t.Errorf("Expected %d calls to SessionRepositoryMock.UpdateTokens at\n%s but found %d calls",
|
||||
mm_atomic.LoadUint64(&m.UpdateTokensMock.expectedInvocations), m.UpdateTokensMock.expectedInvocationsOrigin, afterUpdateTokensCounter)
|
||||
}
|
||||
}
|
||||
|
||||
// MinimockFinish checks that all mocked methods have been called the expected number of times
|
||||
func (m *SessionRepositoryMock) MinimockFinish() {
|
||||
m.finishOnce.Do(func() {
|
||||
@@ -2517,6 +2931,8 @@ func (m *SessionRepositoryMock) MinimockFinish() {
|
||||
m.MinimockRevokeByAccessTokenInspect()
|
||||
|
||||
m.MinimockUpdateAccessTokenInspect()
|
||||
|
||||
m.MinimockUpdateTokensInspect()
|
||||
}
|
||||
})
|
||||
}
|
||||
@@ -2546,5 +2962,6 @@ func (m *SessionRepositoryMock) minimockDone() bool {
|
||||
m.MinimockIsAccessTokenValidDone() &&
|
||||
m.MinimockRevokeDone() &&
|
||||
m.MinimockRevokeByAccessTokenDone() &&
|
||||
m.MinimockUpdateAccessTokenDone()
|
||||
m.MinimockUpdateAccessTokenDone() &&
|
||||
m.MinimockUpdateTokensDone()
|
||||
}
|
||||
|
||||
@@ -26,6 +26,7 @@ type SessionRepository interface {
|
||||
Create(ctx context.Context, session *model.Session) error
|
||||
FindByRefreshToken(ctx context.Context, token string) (*model.Session, error)
|
||||
UpdateAccessToken(ctx context.Context, refreshToken, newAccessToken string) error
|
||||
UpdateTokens(ctx context.Context, oldRefreshToken, newAccessToken, newRefreshToken string) error
|
||||
Revoke(ctx context.Context, refreshToken string) error
|
||||
RevokeByAccessToken(ctx context.Context, accessToken string) error
|
||||
IsAccessTokenValid(ctx context.Context, accessToken string) (bool, error)
|
||||
|
||||
@@ -96,6 +96,32 @@ func (r *sessionRepository) UpdateAccessToken(ctx context.Context, refreshToken,
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *sessionRepository) UpdateTokens(ctx context.Context, oldRefreshToken, newAccessToken, newRefreshToken string) error {
|
||||
query := r.qb.Update("sessions").
|
||||
Set("access_token", newAccessToken).
|
||||
Set("refresh_token", newRefreshToken).
|
||||
Where(sq.And{
|
||||
sq.Eq{"refresh_token": oldRefreshToken},
|
||||
sq.Expr("revoked_at IS NULL"),
|
||||
})
|
||||
|
||||
sqlQuery, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
result, err := r.pool.Exec(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to update tokens", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return errs.NewBusinessError(errs.RefreshInvalid, "refresh token is invalid or already used")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *sessionRepository) Revoke(ctx context.Context, refreshToken string) error {
|
||||
query := r.qb.Update("sessions").
|
||||
Set("revoked_at", time.Now()).
|
||||
|
||||
@@ -9,6 +9,7 @@ import (
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/crypto"
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/jwt"
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/validation"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
@@ -40,8 +41,7 @@ func (s *authService) Login(ctx context.Context, email, password, ip, userAgent
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
passwordHash := crypto.PasswordHash(password)
|
||||
if user.PasswordHash != passwordHash {
|
||||
if !crypto.PasswordVerify(password, user.PasswordHash) {
|
||||
return "", "", errors.NewBusinessError(errors.AuthInvalidCredentials, "Invalid email or password")
|
||||
}
|
||||
|
||||
@@ -71,22 +71,27 @@ func (s *authService) Login(ctx context.Context, email, password, ip, userAgent
|
||||
return accessToken, refreshToken, nil
|
||||
}
|
||||
|
||||
func (s *authService) Refresh(ctx context.Context, refreshToken string) (string, error) {
|
||||
func (s *authService) Refresh(ctx context.Context, refreshToken string) (string, string, error) {
|
||||
session, err := s.sessionRepo.FindByRefreshToken(ctx, refreshToken)
|
||||
if err != nil {
|
||||
return "", err
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
newAccessToken, err := jwt.GenerateAccessToken(session.UserID, s.jwtSecret)
|
||||
if err != nil {
|
||||
return "", errors.NewInternalError(errors.InternalError, "failed to generate access token", err)
|
||||
return "", "", errors.NewInternalError(errors.InternalError, "failed to generate access token", err)
|
||||
}
|
||||
|
||||
if err := s.sessionRepo.UpdateAccessToken(ctx, refreshToken, newAccessToken); err != nil {
|
||||
return "", err
|
||||
newRefreshToken, err := jwt.GenerateRefreshToken(session.UserID, s.jwtSecret)
|
||||
if err != nil {
|
||||
return "", "", errors.NewInternalError(errors.InternalError, "failed to generate refresh token", err)
|
||||
}
|
||||
|
||||
return newAccessToken, nil
|
||||
if err := s.sessionRepo.UpdateTokens(ctx, refreshToken, newAccessToken, newRefreshToken); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
return newAccessToken, newRefreshToken, nil
|
||||
}
|
||||
|
||||
func (s *authService) Validate(ctx context.Context, accessToken string) (int, error) {
|
||||
@@ -121,6 +126,10 @@ func (s *authService) Logout(ctx context.Context, accessToken string) error {
|
||||
}
|
||||
|
||||
func (s *authService) Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error) {
|
||||
if err := validation.ValidateRegistration(email, password, name, phone); err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
|
||||
_, err = s.inviteRepo.FindActiveByCode(ctx, inviteCode)
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
|
||||
@@ -11,7 +11,7 @@ import (
|
||||
type AuthService interface {
|
||||
Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error)
|
||||
Login(ctx context.Context, email, password, ip, userAgent string) (accessToken, refreshToken string, err error)
|
||||
Refresh(ctx context.Context, refreshToken string) (string, error)
|
||||
Refresh(ctx context.Context, refreshToken string) (newAccessToken, newRefreshToken string, err error)
|
||||
Validate(ctx context.Context, accessToken string) (int, error)
|
||||
Logout(ctx context.Context, accessToken string) error
|
||||
}
|
||||
|
||||
@@ -10,6 +10,7 @@ import (
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/repository"
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/fileparser"
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/validation"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
@@ -45,6 +46,14 @@ func NewRequestService(
|
||||
}
|
||||
|
||||
func (s *requestService) CreateTZ(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (uuid.UUID, string, error) {
|
||||
if err := validation.ValidateRequestTxt(requestTxt); err != nil {
|
||||
return uuid.Nil, "", err
|
||||
}
|
||||
|
||||
if err := validation.ValidateFileSize(len(fileData)); err != nil {
|
||||
return uuid.Nil, "", err
|
||||
}
|
||||
|
||||
combinedText := requestTxt
|
||||
|
||||
if len(fileData) > 0 && fileName != "" {
|
||||
|
||||
@@ -267,22 +267,24 @@ func (s *Suite) TestAuthService_Login_EmptyIPAndUserAgent() {
|
||||
func (s *Suite) TestAuthService_Refresh_Success() {
|
||||
session := createTestSession(1)
|
||||
s.sessionRepo.FindByRefreshTokenMock.Return(session, nil)
|
||||
s.sessionRepo.UpdateAccessTokenMock.Return(nil)
|
||||
s.sessionRepo.UpdateTokensMock.Return(nil)
|
||||
|
||||
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
|
||||
s.NoError(err)
|
||||
s.NotEmpty(accessToken)
|
||||
s.NotEmpty(refreshToken)
|
||||
}
|
||||
|
||||
func (s *Suite) TestAuthService_Refresh_RefreshInvalid() {
|
||||
err := apperrors.NewBusinessError(apperrors.RefreshInvalid, "refresh token is invalid or expired")
|
||||
s.sessionRepo.FindByRefreshTokenMock.Return(nil, err)
|
||||
|
||||
accessToken, refreshErr := s.authService.Refresh(s.ctx, "invalid-token")
|
||||
accessToken, refreshToken, refreshErr := s.authService.Refresh(s.ctx, "invalid-token")
|
||||
|
||||
s.Error(refreshErr)
|
||||
s.Empty(accessToken)
|
||||
s.Empty(refreshToken)
|
||||
|
||||
var appErr *apperrors.AppError
|
||||
s.True(errors.As(refreshErr, &appErr))
|
||||
@@ -293,10 +295,11 @@ func (s *Suite) TestAuthService_Refresh_DatabaseError_OnFindSession() {
|
||||
dbErr := apperrors.NewInternalError(apperrors.DatabaseError, "failed to find session", nil)
|
||||
s.sessionRepo.FindByRefreshTokenMock.Return(nil, dbErr)
|
||||
|
||||
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
|
||||
s.Error(err)
|
||||
s.Empty(accessToken)
|
||||
s.Empty(refreshToken)
|
||||
|
||||
var appErr *apperrors.AppError
|
||||
s.True(errors.As(err, &appErr))
|
||||
@@ -307,34 +310,37 @@ func (s *Suite) TestAuthService_Refresh_EmptyToken() {
|
||||
err := apperrors.NewBusinessError(apperrors.RefreshInvalid, "refresh token is invalid or expired")
|
||||
s.sessionRepo.FindByRefreshTokenMock.Return(nil, err)
|
||||
|
||||
accessToken, refreshErr := s.authService.Refresh(s.ctx, "")
|
||||
accessToken, refreshToken, refreshErr := s.authService.Refresh(s.ctx, "")
|
||||
|
||||
s.Error(refreshErr)
|
||||
s.Empty(accessToken)
|
||||
s.Empty(refreshToken)
|
||||
}
|
||||
|
||||
func (s *Suite) TestAuthService_Refresh_UpdateAccessTokenError() {
|
||||
func (s *Suite) TestAuthService_Refresh_UpdateTokensError() {
|
||||
session := createTestSession(1)
|
||||
s.sessionRepo.FindByRefreshTokenMock.Return(session, nil)
|
||||
|
||||
dbErr := apperrors.NewInternalError(apperrors.DatabaseError, "failed to update access token", nil)
|
||||
s.sessionRepo.UpdateAccessTokenMock.Return(dbErr)
|
||||
dbErr := apperrors.NewInternalError(apperrors.DatabaseError, "failed to update tokens", nil)
|
||||
s.sessionRepo.UpdateTokensMock.Return(dbErr)
|
||||
|
||||
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
|
||||
s.Error(err)
|
||||
s.Empty(accessToken)
|
||||
s.Empty(refreshToken)
|
||||
}
|
||||
|
||||
func (s *Suite) TestAuthService_Refresh_UserIDZero() {
|
||||
session := createTestSession(0)
|
||||
s.sessionRepo.FindByRefreshTokenMock.Return(session, nil)
|
||||
s.sessionRepo.UpdateAccessTokenMock.Return(nil)
|
||||
s.sessionRepo.UpdateTokensMock.Return(nil)
|
||||
|
||||
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
|
||||
|
||||
s.NoError(err)
|
||||
s.NotEmpty(accessToken)
|
||||
s.NotEmpty(refreshToken)
|
||||
}
|
||||
|
||||
func (s *Suite) TestAuthService_Validate_Success() {
|
||||
|
||||
@@ -6,11 +6,12 @@ import (
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/hex"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
)
|
||||
|
||||
type Crypto struct {
|
||||
@@ -31,10 +32,19 @@ func (c *Crypto) EmailHash(email string) string {
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
}
|
||||
|
||||
const bcryptCost = 12
|
||||
|
||||
func PasswordHash(password string) string {
|
||||
h := sha512.New()
|
||||
h.Write([]byte(password))
|
||||
return hex.EncodeToString(h.Sum(nil))
|
||||
hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return string(hash)
|
||||
}
|
||||
|
||||
func PasswordVerify(password, hash string) bool {
|
||||
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
|
||||
return err == nil
|
||||
}
|
||||
|
||||
func (c *Crypto) getKey() []byte {
|
||||
|
||||
@@ -15,6 +15,13 @@ const (
|
||||
UnsupportedFileFormat = "UNSUPPORTED_FILE_FORMAT"
|
||||
FileProcessingError = "FILE_PROCESSING_ERROR"
|
||||
|
||||
ValidationInvalidEmail = "VALIDATION_INVALID_EMAIL"
|
||||
ValidationInvalidPassword = "VALIDATION_INVALID_PASSWORD"
|
||||
ValidationInvalidPhone = "VALIDATION_INVALID_PHONE"
|
||||
ValidationInvalidName = "VALIDATION_INVALID_NAME"
|
||||
ValidationFileTooLarge = "VALIDATION_FILE_TOO_LARGE"
|
||||
ValidationRequestTooLong = "VALIDATION_REQUEST_TOO_LONG"
|
||||
|
||||
DatabaseError = "DATABASE_ERROR"
|
||||
EncryptionError = "ENCRYPTION_ERROR"
|
||||
AIAPIError = "AI_API_ERROR"
|
||||
|
||||
156
pkg/validation/validation.go
Normal file
156
pkg/validation/validation.go
Normal file
@@ -0,0 +1,156 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"net/mail"
|
||||
"regexp"
|
||||
"strings"
|
||||
"unicode"
|
||||
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
)
|
||||
|
||||
const (
|
||||
MinPasswordLength = 8
|
||||
MaxPasswordLength = 128
|
||||
MaxEmailLength = 254
|
||||
MaxNameLength = 100
|
||||
MaxPhoneLength = 20
|
||||
MaxRequestTxtLen = 50000
|
||||
MaxFileSizeBytes = 10 * 1024 * 1024
|
||||
)
|
||||
|
||||
var (
|
||||
phoneRegex = regexp.MustCompile(`^\+?[1-9]\d{6,14}$`)
|
||||
)
|
||||
|
||||
func ValidateEmail(email string) error {
|
||||
if email == "" {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidEmail, "email is required")
|
||||
}
|
||||
|
||||
if len(email) > MaxEmailLength {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidEmail, "email is too long")
|
||||
}
|
||||
|
||||
_, err := mail.ParseAddress(email)
|
||||
if err != nil {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidEmail, "invalid email format")
|
||||
}
|
||||
|
||||
parts := strings.Split(email, "@")
|
||||
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidEmail, "invalid email format")
|
||||
}
|
||||
|
||||
if strings.Contains(parts[1], "..") {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidEmail, "invalid email format")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidatePassword(password string) error {
|
||||
if password == "" {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password is required")
|
||||
}
|
||||
|
||||
if len(password) < MinPasswordLength {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password must be at least 8 characters")
|
||||
}
|
||||
|
||||
if len(password) > MaxPasswordLength {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password is too long")
|
||||
}
|
||||
|
||||
var hasUpper, hasLower, hasDigit bool
|
||||
for _, c := range password {
|
||||
switch {
|
||||
case unicode.IsUpper(c):
|
||||
hasUpper = true
|
||||
case unicode.IsLower(c):
|
||||
hasLower = true
|
||||
case unicode.IsDigit(c):
|
||||
hasDigit = true
|
||||
}
|
||||
}
|
||||
|
||||
if !hasUpper || !hasLower || !hasDigit {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password must contain uppercase, lowercase and digit")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidatePhone(phone string) error {
|
||||
if phone == "" {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidPhone, "phone is required")
|
||||
}
|
||||
|
||||
if len(phone) > MaxPhoneLength {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidPhone, "phone is too long")
|
||||
}
|
||||
|
||||
cleaned := strings.ReplaceAll(phone, " ", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, "-", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, "(", "")
|
||||
cleaned = strings.ReplaceAll(cleaned, ")", "")
|
||||
|
||||
if !phoneRegex.MatchString(cleaned) {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidPhone, "invalid phone format")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateName(name string) error {
|
||||
if name == "" {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidName, "name is required")
|
||||
}
|
||||
|
||||
if len(name) > MaxNameLength {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidName, "name is too long")
|
||||
}
|
||||
|
||||
trimmed := strings.TrimSpace(name)
|
||||
if trimmed == "" {
|
||||
return errors.NewBusinessError(errors.ValidationInvalidName, "name cannot be only whitespace")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateRequestTxt(txt string) error {
|
||||
if len(txt) > MaxRequestTxtLen {
|
||||
return errors.NewBusinessError(errors.ValidationRequestTooLong, "request text exceeds 50000 characters limit")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateFileSize(size int) error {
|
||||
if size > MaxFileSizeBytes {
|
||||
return errors.NewBusinessError(errors.ValidationFileTooLarge, "file size exceeds 10MB limit")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func ValidateRegistration(email, password, name, phone string) error {
|
||||
if err := ValidateEmail(email); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidatePassword(password); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidateName(name); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := ValidatePhone(phone); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
222
pkg/validation/validation_test.go
Normal file
222
pkg/validation/validation_test.go
Normal file
@@ -0,0 +1,222 @@
|
||||
package validation
|
||||
|
||||
import (
|
||||
"strings"
|
||||
"testing"
|
||||
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestValidateEmail(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
wantErr bool
|
||||
wantCode string
|
||||
}{
|
||||
{"valid email", "test@example.com", false, ""},
|
||||
{"valid with subdomain", "test@sub.example.com", false, ""},
|
||||
{"valid with plus", "test+tag@example.com", false, ""},
|
||||
{"valid with dots", "test.name@example.com", false, ""},
|
||||
{"empty", "", true, errors.ValidationInvalidEmail},
|
||||
{"no at sign", "testexample.com", true, errors.ValidationInvalidEmail},
|
||||
{"no domain", "test@", true, errors.ValidationInvalidEmail},
|
||||
{"no local part", "@example.com", true, errors.ValidationInvalidEmail},
|
||||
{"double dots in domain", "test@example..com", true, errors.ValidationInvalidEmail},
|
||||
{"too long", strings.Repeat("a", 255) + "@example.com", true, errors.ValidationInvalidEmail},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateEmail(tt.email)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
assert.Equal(t, tt.wantCode, appErr.Code)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePassword(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
password string
|
||||
wantErr bool
|
||||
wantCode string
|
||||
}{
|
||||
{"valid password", "Abcd1234", false, ""},
|
||||
{"valid with special chars", "Abcd1234!", false, ""},
|
||||
{"empty", "", true, errors.ValidationInvalidPassword},
|
||||
{"too short", "Ab1", true, errors.ValidationInvalidPassword},
|
||||
{"no uppercase", "abcd1234", true, errors.ValidationInvalidPassword},
|
||||
{"no lowercase", "ABCD1234", true, errors.ValidationInvalidPassword},
|
||||
{"no digit", "Abcdefgh", true, errors.ValidationInvalidPassword},
|
||||
{"only digits", "12345678", true, errors.ValidationInvalidPassword},
|
||||
{"only lowercase", "abcdefgh", true, errors.ValidationInvalidPassword},
|
||||
{"only uppercase", "ABCDEFGH", true, errors.ValidationInvalidPassword},
|
||||
{"too long", strings.Repeat("Aa1", 50), true, errors.ValidationInvalidPassword},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidatePassword(tt.password)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
assert.Equal(t, tt.wantCode, appErr.Code)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidatePhone(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
phone string
|
||||
wantErr bool
|
||||
wantCode string
|
||||
}{
|
||||
{"valid international", "+1234567890", false, ""},
|
||||
{"valid with country code", "+79123456789", false, ""},
|
||||
{"valid without plus", "1234567890", false, ""},
|
||||
{"empty", "", true, errors.ValidationInvalidPhone},
|
||||
{"too short", "123", true, errors.ValidationInvalidPhone},
|
||||
{"letters", "abcdefgh", true, errors.ValidationInvalidPhone},
|
||||
{"too long", "+123456789012345678901", true, errors.ValidationInvalidPhone},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidatePhone(tt.phone)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
assert.Equal(t, tt.wantCode, appErr.Code)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateName(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
value string
|
||||
wantErr bool
|
||||
wantCode string
|
||||
}{
|
||||
{"valid name", "John Doe", false, ""},
|
||||
{"valid cyrillic", "Иван Иванов", false, ""},
|
||||
{"empty", "", true, errors.ValidationInvalidName},
|
||||
{"only whitespace", " ", true, errors.ValidationInvalidName},
|
||||
{"too long", strings.Repeat("a", 101), true, errors.ValidationInvalidName},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateName(tt.value)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
assert.Equal(t, tt.wantCode, appErr.Code)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRequestTxt(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
txt string
|
||||
wantErr bool
|
||||
wantCode string
|
||||
}{
|
||||
{"valid short", "Test request", false, ""},
|
||||
{"empty is valid", "", false, ""},
|
||||
{"max length", strings.Repeat("a", MaxRequestTxtLen), false, ""},
|
||||
{"too long", strings.Repeat("a", MaxRequestTxtLen+1), true, errors.ValidationRequestTooLong},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateRequestTxt(tt.txt)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
assert.Equal(t, tt.wantCode, appErr.Code)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateFileSize(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
size int
|
||||
wantErr bool
|
||||
wantCode string
|
||||
}{
|
||||
{"zero", 0, false, ""},
|
||||
{"small file", 1024, false, ""},
|
||||
{"max size", MaxFileSizeBytes, false, ""},
|
||||
{"too large", MaxFileSizeBytes + 1, true, errors.ValidationFileTooLarge},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateFileSize(tt.size)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
if appErr, ok := err.(*errors.AppError); ok {
|
||||
assert.Equal(t, tt.wantCode, appErr.Code)
|
||||
}
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestValidateRegistration(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
email string
|
||||
password string
|
||||
userName string
|
||||
phone string
|
||||
wantErr bool
|
||||
}{
|
||||
{"valid", "test@example.com", "Abcd1234", "John Doe", "+1234567890", false},
|
||||
{"invalid email", "invalid", "Abcd1234", "John Doe", "+1234567890", true},
|
||||
{"invalid password", "test@example.com", "weak", "John Doe", "+1234567890", true},
|
||||
{"invalid name", "test@example.com", "Abcd1234", "", "+1234567890", true},
|
||||
{"invalid phone", "test@example.com", "Abcd1234", "John Doe", "abc", true},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
err := ValidateRegistration(tt.email, tt.password, tt.userName, tt.phone)
|
||||
if tt.wantErr {
|
||||
assert.Error(t, err)
|
||||
} else {
|
||||
assert.NoError(t, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
@@ -154,6 +154,7 @@ func (s *IntegrationSuite) createTestUser(email, password string) {
|
||||
|
||||
emailHash := cryptoHelper.EmailHash(email)
|
||||
passwordHash := crypto.PasswordHash(password)
|
||||
s.Require().NotEmpty(passwordHash, "password hash should not be empty")
|
||||
|
||||
query := `
|
||||
INSERT INTO users (email, email_hash, password_hash, phone, user_name, company_name, balance, payment_status, invites_issued, invites_limit)
|
||||
|
||||
1206
tests/security_test.go
Normal file
1206
tests/security_test.go
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user