add service
All checks were successful
Deploy Smart Search Backend / deploy (push) Successful in 1m47s

This commit is contained in:
vallyenfail
2026-01-20 19:02:06 +03:00
parent f8db0fd9e6
commit 8b9554720d
15 changed files with 2109 additions and 38 deletions

View File

@@ -49,14 +49,14 @@ func (h *AuthHandler) Login(ctx context.Context, req *pb.LoginRequest) (*pb.Logi
} }
func (h *AuthHandler) Refresh(ctx context.Context, req *pb.RefreshRequest) (*pb.RefreshResponse, error) { func (h *AuthHandler) Refresh(ctx context.Context, req *pb.RefreshRequest) (*pb.RefreshResponse, error) {
accessToken, err := h.authService.Refresh(ctx, req.RefreshToken) accessToken, refreshToken, err := h.authService.Refresh(ctx, req.RefreshToken)
if err != nil { if err != nil {
return nil, errors.ToGRPCError(err, h.logger, "AuthService.Refresh") return nil, errors.ToGRPCError(err, h.logger, "AuthService.Refresh")
} }
return &pb.RefreshResponse{ return &pb.RefreshResponse{
AccessToken: accessToken, AccessToken: accessToken,
RefreshToken: req.RefreshToken, RefreshToken: refreshToken,
}, nil }, nil
} }

View File

@@ -32,7 +32,7 @@ type AuthServiceMock struct {
beforeLogoutCounter uint64 beforeLogoutCounter uint64
LogoutMock mAuthServiceMockLogout LogoutMock mAuthServiceMockLogout
funcRefresh func(ctx context.Context, refreshToken string) (s1 string, err error) funcRefresh func(ctx context.Context, refreshToken string) (newAccessToken string, newRefreshToken string, err error)
funcRefreshOrigin string funcRefreshOrigin string
inspectFuncRefresh func(ctx context.Context, refreshToken string) inspectFuncRefresh func(ctx context.Context, refreshToken string)
afterRefreshCounter uint64 afterRefreshCounter uint64
@@ -899,7 +899,8 @@ type AuthServiceMockRefreshParamPtrs struct {
// AuthServiceMockRefreshResults contains results of the AuthService.Refresh // AuthServiceMockRefreshResults contains results of the AuthService.Refresh
type AuthServiceMockRefreshResults struct { type AuthServiceMockRefreshResults struct {
s1 string newAccessToken string
newRefreshToken string
err error err error
} }
@@ -1003,7 +1004,7 @@ func (mmRefresh *mAuthServiceMockRefresh) Inspect(f func(ctx context.Context, re
} }
// Return sets up results that will be returned by AuthService.Refresh // Return sets up results that will be returned by AuthService.Refresh
func (mmRefresh *mAuthServiceMockRefresh) Return(s1 string, err error) *AuthServiceMock { func (mmRefresh *mAuthServiceMockRefresh) Return(newAccessToken string, newRefreshToken string, err error) *AuthServiceMock {
if mmRefresh.mock.funcRefresh != nil { if mmRefresh.mock.funcRefresh != nil {
mmRefresh.mock.t.Fatalf("AuthServiceMock.Refresh mock is already set by Set") mmRefresh.mock.t.Fatalf("AuthServiceMock.Refresh mock is already set by Set")
} }
@@ -1011,13 +1012,13 @@ func (mmRefresh *mAuthServiceMockRefresh) Return(s1 string, err error) *AuthServ
if mmRefresh.defaultExpectation == nil { if mmRefresh.defaultExpectation == nil {
mmRefresh.defaultExpectation = &AuthServiceMockRefreshExpectation{mock: mmRefresh.mock} mmRefresh.defaultExpectation = &AuthServiceMockRefreshExpectation{mock: mmRefresh.mock}
} }
mmRefresh.defaultExpectation.results = &AuthServiceMockRefreshResults{s1, err} mmRefresh.defaultExpectation.results = &AuthServiceMockRefreshResults{newAccessToken, newRefreshToken, err}
mmRefresh.defaultExpectation.returnOrigin = minimock.CallerInfo(1) mmRefresh.defaultExpectation.returnOrigin = minimock.CallerInfo(1)
return mmRefresh.mock return mmRefresh.mock
} }
// Set uses given function f to mock the AuthService.Refresh method // Set uses given function f to mock the AuthService.Refresh method
func (mmRefresh *mAuthServiceMockRefresh) Set(f func(ctx context.Context, refreshToken string) (s1 string, err error)) *AuthServiceMock { func (mmRefresh *mAuthServiceMockRefresh) Set(f func(ctx context.Context, refreshToken string) (newAccessToken string, newRefreshToken string, err error)) *AuthServiceMock {
if mmRefresh.defaultExpectation != nil { if mmRefresh.defaultExpectation != nil {
mmRefresh.mock.t.Fatalf("Default expectation is already set for the AuthService.Refresh method") mmRefresh.mock.t.Fatalf("Default expectation is already set for the AuthService.Refresh method")
} }
@@ -1048,8 +1049,8 @@ func (mmRefresh *mAuthServiceMockRefresh) When(ctx context.Context, refreshToken
} }
// Then sets up AuthService.Refresh return parameters for the expectation previously defined by the When method // Then sets up AuthService.Refresh return parameters for the expectation previously defined by the When method
func (e *AuthServiceMockRefreshExpectation) Then(s1 string, err error) *AuthServiceMock { func (e *AuthServiceMockRefreshExpectation) Then(newAccessToken string, newRefreshToken string, err error) *AuthServiceMock {
e.results = &AuthServiceMockRefreshResults{s1, err} e.results = &AuthServiceMockRefreshResults{newAccessToken, newRefreshToken, err}
return e.mock return e.mock
} }
@@ -1075,7 +1076,7 @@ func (mmRefresh *mAuthServiceMockRefresh) invocationsDone() bool {
} }
// Refresh implements mm_service.AuthService // Refresh implements mm_service.AuthService
func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken string) (s1 string, err error) { func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken string) (newAccessToken string, newRefreshToken string, err error) {
mm_atomic.AddUint64(&mmRefresh.beforeRefreshCounter, 1) mm_atomic.AddUint64(&mmRefresh.beforeRefreshCounter, 1)
defer mm_atomic.AddUint64(&mmRefresh.afterRefreshCounter, 1) defer mm_atomic.AddUint64(&mmRefresh.afterRefreshCounter, 1)
@@ -1095,7 +1096,7 @@ func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken stri
for _, e := range mmRefresh.RefreshMock.expectations { for _, e := range mmRefresh.RefreshMock.expectations {
if minimock.Equal(*e.params, mm_params) { if minimock.Equal(*e.params, mm_params) {
mm_atomic.AddUint64(&e.Counter, 1) mm_atomic.AddUint64(&e.Counter, 1)
return e.results.s1, e.results.err return e.results.newAccessToken, e.results.newRefreshToken, e.results.err
} }
} }
@@ -1127,7 +1128,7 @@ func (mmRefresh *AuthServiceMock) Refresh(ctx context.Context, refreshToken stri
if mm_results == nil { if mm_results == nil {
mmRefresh.t.Fatal("No results are set for the AuthServiceMock.Refresh") mmRefresh.t.Fatal("No results are set for the AuthServiceMock.Refresh")
} }
return (*mm_results).s1, (*mm_results).err return (*mm_results).newAccessToken, (*mm_results).newRefreshToken, (*mm_results).err
} }
if mmRefresh.funcRefresh != nil { if mmRefresh.funcRefresh != nil {
return mmRefresh.funcRefresh(ctx, refreshToken) return mmRefresh.funcRefresh(ctx, refreshToken)

View File

@@ -67,6 +67,13 @@ type SessionRepositoryMock struct {
afterUpdateAccessTokenCounter uint64 afterUpdateAccessTokenCounter uint64
beforeUpdateAccessTokenCounter uint64 beforeUpdateAccessTokenCounter uint64
UpdateAccessTokenMock mSessionRepositoryMockUpdateAccessToken UpdateAccessTokenMock mSessionRepositoryMockUpdateAccessToken
funcUpdateTokens func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) (err error)
funcUpdateTokensOrigin string
inspectFuncUpdateTokens func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string)
afterUpdateTokensCounter uint64
beforeUpdateTokensCounter uint64
UpdateTokensMock mSessionRepositoryMockUpdateTokens
} }
// NewSessionRepositoryMock returns a mock for mm_repository.SessionRepository // NewSessionRepositoryMock returns a mock for mm_repository.SessionRepository
@@ -98,6 +105,9 @@ func NewSessionRepositoryMock(t minimock.Tester) *SessionRepositoryMock {
m.UpdateAccessTokenMock = mSessionRepositoryMockUpdateAccessToken{mock: m} m.UpdateAccessTokenMock = mSessionRepositoryMockUpdateAccessToken{mock: m}
m.UpdateAccessTokenMock.callArgs = []*SessionRepositoryMockUpdateAccessTokenParams{} m.UpdateAccessTokenMock.callArgs = []*SessionRepositoryMockUpdateAccessTokenParams{}
m.UpdateTokensMock = mSessionRepositoryMockUpdateTokens{mock: m}
m.UpdateTokensMock.callArgs = []*SessionRepositoryMockUpdateTokensParams{}
t.Cleanup(m.MinimockFinish) t.Cleanup(m.MinimockFinish)
return m return m
@@ -2500,6 +2510,410 @@ func (m *SessionRepositoryMock) MinimockUpdateAccessTokenInspect() {
} }
} }
type mSessionRepositoryMockUpdateTokens struct {
optional bool
mock *SessionRepositoryMock
defaultExpectation *SessionRepositoryMockUpdateTokensExpectation
expectations []*SessionRepositoryMockUpdateTokensExpectation
callArgs []*SessionRepositoryMockUpdateTokensParams
mutex sync.RWMutex
expectedInvocations uint64
expectedInvocationsOrigin string
}
// SessionRepositoryMockUpdateTokensExpectation specifies expectation struct of the SessionRepository.UpdateTokens
type SessionRepositoryMockUpdateTokensExpectation struct {
mock *SessionRepositoryMock
params *SessionRepositoryMockUpdateTokensParams
paramPtrs *SessionRepositoryMockUpdateTokensParamPtrs
expectationOrigins SessionRepositoryMockUpdateTokensExpectationOrigins
results *SessionRepositoryMockUpdateTokensResults
returnOrigin string
Counter uint64
}
// SessionRepositoryMockUpdateTokensParams contains parameters of the SessionRepository.UpdateTokens
type SessionRepositoryMockUpdateTokensParams struct {
ctx context.Context
oldRefreshToken string
newAccessToken string
newRefreshToken string
}
// SessionRepositoryMockUpdateTokensParamPtrs contains pointers to parameters of the SessionRepository.UpdateTokens
type SessionRepositoryMockUpdateTokensParamPtrs struct {
ctx *context.Context
oldRefreshToken *string
newAccessToken *string
newRefreshToken *string
}
// SessionRepositoryMockUpdateTokensResults contains results of the SessionRepository.UpdateTokens
type SessionRepositoryMockUpdateTokensResults struct {
err error
}
// SessionRepositoryMockUpdateTokensOrigins contains origins of expectations of the SessionRepository.UpdateTokens
type SessionRepositoryMockUpdateTokensExpectationOrigins struct {
origin string
originCtx string
originOldRefreshToken string
originNewAccessToken string
originNewRefreshToken string
}
// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning
// the test will fail minimock's automatic final call check if the mocked method was not called at least once.
// Optional() makes method check to work in '0 or more' mode.
// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to
// catch the problems when the expected method call is totally skipped during test run.
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Optional() *mSessionRepositoryMockUpdateTokens {
mmUpdateTokens.optional = true
return mmUpdateTokens
}
// Expect sets up expected params for SessionRepository.UpdateTokens
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Expect(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) *mSessionRepositoryMockUpdateTokens {
if mmUpdateTokens.mock.funcUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
}
if mmUpdateTokens.defaultExpectation == nil {
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
}
if mmUpdateTokens.defaultExpectation.paramPtrs != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by ExpectParams functions")
}
mmUpdateTokens.defaultExpectation.params = &SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken}
mmUpdateTokens.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1)
for _, e := range mmUpdateTokens.expectations {
if minimock.Equal(e.params, mmUpdateTokens.defaultExpectation.params) {
mmUpdateTokens.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmUpdateTokens.defaultExpectation.params)
}
}
return mmUpdateTokens
}
// ExpectCtxParam1 sets up expected param ctx for SessionRepository.UpdateTokens
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectCtxParam1(ctx context.Context) *mSessionRepositoryMockUpdateTokens {
if mmUpdateTokens.mock.funcUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
}
if mmUpdateTokens.defaultExpectation == nil {
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
}
if mmUpdateTokens.defaultExpectation.params != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
}
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
}
mmUpdateTokens.defaultExpectation.paramPtrs.ctx = &ctx
mmUpdateTokens.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1)
return mmUpdateTokens
}
// ExpectOldRefreshTokenParam2 sets up expected param oldRefreshToken for SessionRepository.UpdateTokens
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectOldRefreshTokenParam2(oldRefreshToken string) *mSessionRepositoryMockUpdateTokens {
if mmUpdateTokens.mock.funcUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
}
if mmUpdateTokens.defaultExpectation == nil {
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
}
if mmUpdateTokens.defaultExpectation.params != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
}
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
}
mmUpdateTokens.defaultExpectation.paramPtrs.oldRefreshToken = &oldRefreshToken
mmUpdateTokens.defaultExpectation.expectationOrigins.originOldRefreshToken = minimock.CallerInfo(1)
return mmUpdateTokens
}
// ExpectNewAccessTokenParam3 sets up expected param newAccessToken for SessionRepository.UpdateTokens
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectNewAccessTokenParam3(newAccessToken string) *mSessionRepositoryMockUpdateTokens {
if mmUpdateTokens.mock.funcUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
}
if mmUpdateTokens.defaultExpectation == nil {
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
}
if mmUpdateTokens.defaultExpectation.params != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
}
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
}
mmUpdateTokens.defaultExpectation.paramPtrs.newAccessToken = &newAccessToken
mmUpdateTokens.defaultExpectation.expectationOrigins.originNewAccessToken = minimock.CallerInfo(1)
return mmUpdateTokens
}
// ExpectNewRefreshTokenParam4 sets up expected param newRefreshToken for SessionRepository.UpdateTokens
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) ExpectNewRefreshTokenParam4(newRefreshToken string) *mSessionRepositoryMockUpdateTokens {
if mmUpdateTokens.mock.funcUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
}
if mmUpdateTokens.defaultExpectation == nil {
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{}
}
if mmUpdateTokens.defaultExpectation.params != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Expect")
}
if mmUpdateTokens.defaultExpectation.paramPtrs == nil {
mmUpdateTokens.defaultExpectation.paramPtrs = &SessionRepositoryMockUpdateTokensParamPtrs{}
}
mmUpdateTokens.defaultExpectation.paramPtrs.newRefreshToken = &newRefreshToken
mmUpdateTokens.defaultExpectation.expectationOrigins.originNewRefreshToken = minimock.CallerInfo(1)
return mmUpdateTokens
}
// Inspect accepts an inspector function that has same arguments as the SessionRepository.UpdateTokens
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Inspect(f func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string)) *mSessionRepositoryMockUpdateTokens {
if mmUpdateTokens.mock.inspectFuncUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("Inspect function is already set for SessionRepositoryMock.UpdateTokens")
}
mmUpdateTokens.mock.inspectFuncUpdateTokens = f
return mmUpdateTokens
}
// Return sets up results that will be returned by SessionRepository.UpdateTokens
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Return(err error) *SessionRepositoryMock {
if mmUpdateTokens.mock.funcUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
}
if mmUpdateTokens.defaultExpectation == nil {
mmUpdateTokens.defaultExpectation = &SessionRepositoryMockUpdateTokensExpectation{mock: mmUpdateTokens.mock}
}
mmUpdateTokens.defaultExpectation.results = &SessionRepositoryMockUpdateTokensResults{err}
mmUpdateTokens.defaultExpectation.returnOrigin = minimock.CallerInfo(1)
return mmUpdateTokens.mock
}
// Set uses given function f to mock the SessionRepository.UpdateTokens method
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Set(f func(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) (err error)) *SessionRepositoryMock {
if mmUpdateTokens.defaultExpectation != nil {
mmUpdateTokens.mock.t.Fatalf("Default expectation is already set for the SessionRepository.UpdateTokens method")
}
if len(mmUpdateTokens.expectations) > 0 {
mmUpdateTokens.mock.t.Fatalf("Some expectations are already set for the SessionRepository.UpdateTokens method")
}
mmUpdateTokens.mock.funcUpdateTokens = f
mmUpdateTokens.mock.funcUpdateTokensOrigin = minimock.CallerInfo(1)
return mmUpdateTokens.mock
}
// When sets expectation for the SessionRepository.UpdateTokens which will trigger the result defined by the following
// Then helper
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) When(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) *SessionRepositoryMockUpdateTokensExpectation {
if mmUpdateTokens.mock.funcUpdateTokens != nil {
mmUpdateTokens.mock.t.Fatalf("SessionRepositoryMock.UpdateTokens mock is already set by Set")
}
expectation := &SessionRepositoryMockUpdateTokensExpectation{
mock: mmUpdateTokens.mock,
params: &SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken},
expectationOrigins: SessionRepositoryMockUpdateTokensExpectationOrigins{origin: minimock.CallerInfo(1)},
}
mmUpdateTokens.expectations = append(mmUpdateTokens.expectations, expectation)
return expectation
}
// Then sets up SessionRepository.UpdateTokens return parameters for the expectation previously defined by the When method
func (e *SessionRepositoryMockUpdateTokensExpectation) Then(err error) *SessionRepositoryMock {
e.results = &SessionRepositoryMockUpdateTokensResults{err}
return e.mock
}
// Times sets number of times SessionRepository.UpdateTokens should be invoked
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Times(n uint64) *mSessionRepositoryMockUpdateTokens {
if n == 0 {
mmUpdateTokens.mock.t.Fatalf("Times of SessionRepositoryMock.UpdateTokens mock can not be zero")
}
mm_atomic.StoreUint64(&mmUpdateTokens.expectedInvocations, n)
mmUpdateTokens.expectedInvocationsOrigin = minimock.CallerInfo(1)
return mmUpdateTokens
}
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) invocationsDone() bool {
if len(mmUpdateTokens.expectations) == 0 && mmUpdateTokens.defaultExpectation == nil && mmUpdateTokens.mock.funcUpdateTokens == nil {
return true
}
totalInvocations := mm_atomic.LoadUint64(&mmUpdateTokens.mock.afterUpdateTokensCounter)
expectedInvocations := mm_atomic.LoadUint64(&mmUpdateTokens.expectedInvocations)
return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations)
}
// UpdateTokens implements mm_repository.SessionRepository
func (mmUpdateTokens *SessionRepositoryMock) UpdateTokens(ctx context.Context, oldRefreshToken string, newAccessToken string, newRefreshToken string) (err error) {
mm_atomic.AddUint64(&mmUpdateTokens.beforeUpdateTokensCounter, 1)
defer mm_atomic.AddUint64(&mmUpdateTokens.afterUpdateTokensCounter, 1)
mmUpdateTokens.t.Helper()
if mmUpdateTokens.inspectFuncUpdateTokens != nil {
mmUpdateTokens.inspectFuncUpdateTokens(ctx, oldRefreshToken, newAccessToken, newRefreshToken)
}
mm_params := SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken}
// Record call args
mmUpdateTokens.UpdateTokensMock.mutex.Lock()
mmUpdateTokens.UpdateTokensMock.callArgs = append(mmUpdateTokens.UpdateTokensMock.callArgs, &mm_params)
mmUpdateTokens.UpdateTokensMock.mutex.Unlock()
for _, e := range mmUpdateTokens.UpdateTokensMock.expectations {
if minimock.Equal(*e.params, mm_params) {
mm_atomic.AddUint64(&e.Counter, 1)
return e.results.err
}
}
if mmUpdateTokens.UpdateTokensMock.defaultExpectation != nil {
mm_atomic.AddUint64(&mmUpdateTokens.UpdateTokensMock.defaultExpectation.Counter, 1)
mm_want := mmUpdateTokens.UpdateTokensMock.defaultExpectation.params
mm_want_ptrs := mmUpdateTokens.UpdateTokensMock.defaultExpectation.paramPtrs
mm_got := SessionRepositoryMockUpdateTokensParams{ctx, oldRefreshToken, newAccessToken, newRefreshToken}
if mm_want_ptrs != nil {
if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) {
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx))
}
if mm_want_ptrs.oldRefreshToken != nil && !minimock.Equal(*mm_want_ptrs.oldRefreshToken, mm_got.oldRefreshToken) {
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter oldRefreshToken, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originOldRefreshToken, *mm_want_ptrs.oldRefreshToken, mm_got.oldRefreshToken, minimock.Diff(*mm_want_ptrs.oldRefreshToken, mm_got.oldRefreshToken))
}
if mm_want_ptrs.newAccessToken != nil && !minimock.Equal(*mm_want_ptrs.newAccessToken, mm_got.newAccessToken) {
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter newAccessToken, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originNewAccessToken, *mm_want_ptrs.newAccessToken, mm_got.newAccessToken, minimock.Diff(*mm_want_ptrs.newAccessToken, mm_got.newAccessToken))
}
if mm_want_ptrs.newRefreshToken != nil && !minimock.Equal(*mm_want_ptrs.newRefreshToken, mm_got.newRefreshToken) {
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameter newRefreshToken, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.originNewRefreshToken, *mm_want_ptrs.newRefreshToken, mm_got.newRefreshToken, minimock.Diff(*mm_want_ptrs.newRefreshToken, mm_got.newRefreshToken))
}
} else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) {
mmUpdateTokens.t.Errorf("SessionRepositoryMock.UpdateTokens got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n",
mmUpdateTokens.UpdateTokensMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got))
}
mm_results := mmUpdateTokens.UpdateTokensMock.defaultExpectation.results
if mm_results == nil {
mmUpdateTokens.t.Fatal("No results are set for the SessionRepositoryMock.UpdateTokens")
}
return (*mm_results).err
}
if mmUpdateTokens.funcUpdateTokens != nil {
return mmUpdateTokens.funcUpdateTokens(ctx, oldRefreshToken, newAccessToken, newRefreshToken)
}
mmUpdateTokens.t.Fatalf("Unexpected call to SessionRepositoryMock.UpdateTokens. %v %v %v %v", ctx, oldRefreshToken, newAccessToken, newRefreshToken)
return
}
// UpdateTokensAfterCounter returns a count of finished SessionRepositoryMock.UpdateTokens invocations
func (mmUpdateTokens *SessionRepositoryMock) UpdateTokensAfterCounter() uint64 {
return mm_atomic.LoadUint64(&mmUpdateTokens.afterUpdateTokensCounter)
}
// UpdateTokensBeforeCounter returns a count of SessionRepositoryMock.UpdateTokens invocations
func (mmUpdateTokens *SessionRepositoryMock) UpdateTokensBeforeCounter() uint64 {
return mm_atomic.LoadUint64(&mmUpdateTokens.beforeUpdateTokensCounter)
}
// Calls returns a list of arguments used in each call to SessionRepositoryMock.UpdateTokens.
// The list is in the same order as the calls were made (i.e. recent calls have a higher index)
func (mmUpdateTokens *mSessionRepositoryMockUpdateTokens) Calls() []*SessionRepositoryMockUpdateTokensParams {
mmUpdateTokens.mutex.RLock()
argCopy := make([]*SessionRepositoryMockUpdateTokensParams, len(mmUpdateTokens.callArgs))
copy(argCopy, mmUpdateTokens.callArgs)
mmUpdateTokens.mutex.RUnlock()
return argCopy
}
// MinimockUpdateTokensDone returns true if the count of the UpdateTokens invocations corresponds
// the number of defined expectations
func (m *SessionRepositoryMock) MinimockUpdateTokensDone() bool {
if m.UpdateTokensMock.optional {
// Optional methods provide '0 or more' call count restriction.
return true
}
for _, e := range m.UpdateTokensMock.expectations {
if mm_atomic.LoadUint64(&e.Counter) < 1 {
return false
}
}
return m.UpdateTokensMock.invocationsDone()
}
// MinimockUpdateTokensInspect logs each unmet expectation
func (m *SessionRepositoryMock) MinimockUpdateTokensInspect() {
for _, e := range m.UpdateTokensMock.expectations {
if mm_atomic.LoadUint64(&e.Counter) < 1 {
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s with params: %#v", e.expectationOrigins.origin, *e.params)
}
}
afterUpdateTokensCounter := mm_atomic.LoadUint64(&m.afterUpdateTokensCounter)
// if default expectation was set then invocations count should be greater than zero
if m.UpdateTokensMock.defaultExpectation != nil && afterUpdateTokensCounter < 1 {
if m.UpdateTokensMock.defaultExpectation.params == nil {
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s", m.UpdateTokensMock.defaultExpectation.returnOrigin)
} else {
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s with params: %#v", m.UpdateTokensMock.defaultExpectation.expectationOrigins.origin, *m.UpdateTokensMock.defaultExpectation.params)
}
}
// if func was set then invocations count should be greater than zero
if m.funcUpdateTokens != nil && afterUpdateTokensCounter < 1 {
m.t.Errorf("Expected call to SessionRepositoryMock.UpdateTokens at\n%s", m.funcUpdateTokensOrigin)
}
if !m.UpdateTokensMock.invocationsDone() && afterUpdateTokensCounter > 0 {
m.t.Errorf("Expected %d calls to SessionRepositoryMock.UpdateTokens at\n%s but found %d calls",
mm_atomic.LoadUint64(&m.UpdateTokensMock.expectedInvocations), m.UpdateTokensMock.expectedInvocationsOrigin, afterUpdateTokensCounter)
}
}
// MinimockFinish checks that all mocked methods have been called the expected number of times // MinimockFinish checks that all mocked methods have been called the expected number of times
func (m *SessionRepositoryMock) MinimockFinish() { func (m *SessionRepositoryMock) MinimockFinish() {
m.finishOnce.Do(func() { m.finishOnce.Do(func() {
@@ -2517,6 +2931,8 @@ func (m *SessionRepositoryMock) MinimockFinish() {
m.MinimockRevokeByAccessTokenInspect() m.MinimockRevokeByAccessTokenInspect()
m.MinimockUpdateAccessTokenInspect() m.MinimockUpdateAccessTokenInspect()
m.MinimockUpdateTokensInspect()
} }
}) })
} }
@@ -2546,5 +2962,6 @@ func (m *SessionRepositoryMock) minimockDone() bool {
m.MinimockIsAccessTokenValidDone() && m.MinimockIsAccessTokenValidDone() &&
m.MinimockRevokeDone() && m.MinimockRevokeDone() &&
m.MinimockRevokeByAccessTokenDone() && m.MinimockRevokeByAccessTokenDone() &&
m.MinimockUpdateAccessTokenDone() m.MinimockUpdateAccessTokenDone() &&
m.MinimockUpdateTokensDone()
} }

View File

@@ -26,6 +26,7 @@ type SessionRepository interface {
Create(ctx context.Context, session *model.Session) error Create(ctx context.Context, session *model.Session) error
FindByRefreshToken(ctx context.Context, token string) (*model.Session, error) FindByRefreshToken(ctx context.Context, token string) (*model.Session, error)
UpdateAccessToken(ctx context.Context, refreshToken, newAccessToken string) error UpdateAccessToken(ctx context.Context, refreshToken, newAccessToken string) error
UpdateTokens(ctx context.Context, oldRefreshToken, newAccessToken, newRefreshToken string) error
Revoke(ctx context.Context, refreshToken string) error Revoke(ctx context.Context, refreshToken string) error
RevokeByAccessToken(ctx context.Context, accessToken string) error RevokeByAccessToken(ctx context.Context, accessToken string) error
IsAccessTokenValid(ctx context.Context, accessToken string) (bool, error) IsAccessTokenValid(ctx context.Context, accessToken string) (bool, error)

View File

@@ -96,6 +96,32 @@ func (r *sessionRepository) UpdateAccessToken(ctx context.Context, refreshToken,
return nil return nil
} }
func (r *sessionRepository) UpdateTokens(ctx context.Context, oldRefreshToken, newAccessToken, newRefreshToken string) error {
query := r.qb.Update("sessions").
Set("access_token", newAccessToken).
Set("refresh_token", newRefreshToken).
Where(sq.And{
sq.Eq{"refresh_token": oldRefreshToken},
sq.Expr("revoked_at IS NULL"),
})
sqlQuery, args, err := query.ToSql()
if err != nil {
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
}
result, err := r.pool.Exec(ctx, sqlQuery, args...)
if err != nil {
return errs.NewInternalError(errs.DatabaseError, "failed to update tokens", err)
}
if result.RowsAffected() == 0 {
return errs.NewBusinessError(errs.RefreshInvalid, "refresh token is invalid or already used")
}
return nil
}
func (r *sessionRepository) Revoke(ctx context.Context, refreshToken string) error { func (r *sessionRepository) Revoke(ctx context.Context, refreshToken string) error {
query := r.qb.Update("sessions"). query := r.qb.Update("sessions").
Set("revoked_at", time.Now()). Set("revoked_at", time.Now()).

View File

@@ -9,6 +9,7 @@ import (
"git.techease.ru/Smart-search/smart-search-back/pkg/crypto" "git.techease.ru/Smart-search/smart-search-back/pkg/crypto"
"git.techease.ru/Smart-search/smart-search-back/pkg/errors" "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
"git.techease.ru/Smart-search/smart-search-back/pkg/jwt" "git.techease.ru/Smart-search/smart-search-back/pkg/jwt"
"git.techease.ru/Smart-search/smart-search-back/pkg/validation"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
) )
@@ -40,8 +41,7 @@ func (s *authService) Login(ctx context.Context, email, password, ip, userAgent
return "", "", err return "", "", err
} }
passwordHash := crypto.PasswordHash(password) if !crypto.PasswordVerify(password, user.PasswordHash) {
if user.PasswordHash != passwordHash {
return "", "", errors.NewBusinessError(errors.AuthInvalidCredentials, "Invalid email or password") return "", "", errors.NewBusinessError(errors.AuthInvalidCredentials, "Invalid email or password")
} }
@@ -71,22 +71,27 @@ func (s *authService) Login(ctx context.Context, email, password, ip, userAgent
return accessToken, refreshToken, nil return accessToken, refreshToken, nil
} }
func (s *authService) Refresh(ctx context.Context, refreshToken string) (string, error) { func (s *authService) Refresh(ctx context.Context, refreshToken string) (string, string, error) {
session, err := s.sessionRepo.FindByRefreshToken(ctx, refreshToken) session, err := s.sessionRepo.FindByRefreshToken(ctx, refreshToken)
if err != nil { if err != nil {
return "", err return "", "", err
} }
newAccessToken, err := jwt.GenerateAccessToken(session.UserID, s.jwtSecret) newAccessToken, err := jwt.GenerateAccessToken(session.UserID, s.jwtSecret)
if err != nil { if err != nil {
return "", errors.NewInternalError(errors.InternalError, "failed to generate access token", err) return "", "", errors.NewInternalError(errors.InternalError, "failed to generate access token", err)
} }
if err := s.sessionRepo.UpdateAccessToken(ctx, refreshToken, newAccessToken); err != nil { newRefreshToken, err := jwt.GenerateRefreshToken(session.UserID, s.jwtSecret)
return "", err if err != nil {
return "", "", errors.NewInternalError(errors.InternalError, "failed to generate refresh token", err)
} }
return newAccessToken, nil if err := s.sessionRepo.UpdateTokens(ctx, refreshToken, newAccessToken, newRefreshToken); err != nil {
return "", "", err
}
return newAccessToken, newRefreshToken, nil
} }
func (s *authService) Validate(ctx context.Context, accessToken string) (int, error) { func (s *authService) Validate(ctx context.Context, accessToken string) (int, error) {
@@ -121,6 +126,10 @@ func (s *authService) Logout(ctx context.Context, accessToken string) error {
} }
func (s *authService) Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error) { func (s *authService) Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error) {
if err := validation.ValidateRegistration(email, password, name, phone); err != nil {
return "", "", err
}
_, err = s.inviteRepo.FindActiveByCode(ctx, inviteCode) _, err = s.inviteRepo.FindActiveByCode(ctx, inviteCode)
if err != nil { if err != nil {
return "", "", err return "", "", err

View File

@@ -11,7 +11,7 @@ import (
type AuthService interface { type AuthService interface {
Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error) Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error)
Login(ctx context.Context, email, password, ip, userAgent string) (accessToken, refreshToken string, err error) Login(ctx context.Context, email, password, ip, userAgent string) (accessToken, refreshToken string, err error)
Refresh(ctx context.Context, refreshToken string) (string, error) Refresh(ctx context.Context, refreshToken string) (newAccessToken, newRefreshToken string, err error)
Validate(ctx context.Context, accessToken string) (int, error) Validate(ctx context.Context, accessToken string) (int, error)
Logout(ctx context.Context, accessToken string) error Logout(ctx context.Context, accessToken string) error
} }

View File

@@ -10,6 +10,7 @@ import (
"git.techease.ru/Smart-search/smart-search-back/internal/repository" "git.techease.ru/Smart-search/smart-search-back/internal/repository"
"git.techease.ru/Smart-search/smart-search-back/pkg/errors" "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
"git.techease.ru/Smart-search/smart-search-back/pkg/fileparser" "git.techease.ru/Smart-search/smart-search-back/pkg/fileparser"
"git.techease.ru/Smart-search/smart-search-back/pkg/validation"
"github.com/google/uuid" "github.com/google/uuid"
"github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5"
) )
@@ -45,6 +46,14 @@ func NewRequestService(
} }
func (s *requestService) CreateTZ(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (uuid.UUID, string, error) { func (s *requestService) CreateTZ(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (uuid.UUID, string, error) {
if err := validation.ValidateRequestTxt(requestTxt); err != nil {
return uuid.Nil, "", err
}
if err := validation.ValidateFileSize(len(fileData)); err != nil {
return uuid.Nil, "", err
}
combinedText := requestTxt combinedText := requestTxt
if len(fileData) > 0 && fileName != "" { if len(fileData) > 0 && fileName != "" {

View File

@@ -267,22 +267,24 @@ func (s *Suite) TestAuthService_Login_EmptyIPAndUserAgent() {
func (s *Suite) TestAuthService_Refresh_Success() { func (s *Suite) TestAuthService_Refresh_Success() {
session := createTestSession(1) session := createTestSession(1)
s.sessionRepo.FindByRefreshTokenMock.Return(session, nil) s.sessionRepo.FindByRefreshTokenMock.Return(session, nil)
s.sessionRepo.UpdateAccessTokenMock.Return(nil) s.sessionRepo.UpdateTokensMock.Return(nil)
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token") accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
s.NoError(err) s.NoError(err)
s.NotEmpty(accessToken) s.NotEmpty(accessToken)
s.NotEmpty(refreshToken)
} }
func (s *Suite) TestAuthService_Refresh_RefreshInvalid() { func (s *Suite) TestAuthService_Refresh_RefreshInvalid() {
err := apperrors.NewBusinessError(apperrors.RefreshInvalid, "refresh token is invalid or expired") err := apperrors.NewBusinessError(apperrors.RefreshInvalid, "refresh token is invalid or expired")
s.sessionRepo.FindByRefreshTokenMock.Return(nil, err) s.sessionRepo.FindByRefreshTokenMock.Return(nil, err)
accessToken, refreshErr := s.authService.Refresh(s.ctx, "invalid-token") accessToken, refreshToken, refreshErr := s.authService.Refresh(s.ctx, "invalid-token")
s.Error(refreshErr) s.Error(refreshErr)
s.Empty(accessToken) s.Empty(accessToken)
s.Empty(refreshToken)
var appErr *apperrors.AppError var appErr *apperrors.AppError
s.True(errors.As(refreshErr, &appErr)) s.True(errors.As(refreshErr, &appErr))
@@ -293,10 +295,11 @@ func (s *Suite) TestAuthService_Refresh_DatabaseError_OnFindSession() {
dbErr := apperrors.NewInternalError(apperrors.DatabaseError, "failed to find session", nil) dbErr := apperrors.NewInternalError(apperrors.DatabaseError, "failed to find session", nil)
s.sessionRepo.FindByRefreshTokenMock.Return(nil, dbErr) s.sessionRepo.FindByRefreshTokenMock.Return(nil, dbErr)
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token") accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
s.Error(err) s.Error(err)
s.Empty(accessToken) s.Empty(accessToken)
s.Empty(refreshToken)
var appErr *apperrors.AppError var appErr *apperrors.AppError
s.True(errors.As(err, &appErr)) s.True(errors.As(err, &appErr))
@@ -307,34 +310,37 @@ func (s *Suite) TestAuthService_Refresh_EmptyToken() {
err := apperrors.NewBusinessError(apperrors.RefreshInvalid, "refresh token is invalid or expired") err := apperrors.NewBusinessError(apperrors.RefreshInvalid, "refresh token is invalid or expired")
s.sessionRepo.FindByRefreshTokenMock.Return(nil, err) s.sessionRepo.FindByRefreshTokenMock.Return(nil, err)
accessToken, refreshErr := s.authService.Refresh(s.ctx, "") accessToken, refreshToken, refreshErr := s.authService.Refresh(s.ctx, "")
s.Error(refreshErr) s.Error(refreshErr)
s.Empty(accessToken) s.Empty(accessToken)
s.Empty(refreshToken)
} }
func (s *Suite) TestAuthService_Refresh_UpdateAccessTokenError() { func (s *Suite) TestAuthService_Refresh_UpdateTokensError() {
session := createTestSession(1) session := createTestSession(1)
s.sessionRepo.FindByRefreshTokenMock.Return(session, nil) s.sessionRepo.FindByRefreshTokenMock.Return(session, nil)
dbErr := apperrors.NewInternalError(apperrors.DatabaseError, "failed to update access token", nil) dbErr := apperrors.NewInternalError(apperrors.DatabaseError, "failed to update tokens", nil)
s.sessionRepo.UpdateAccessTokenMock.Return(dbErr) s.sessionRepo.UpdateTokensMock.Return(dbErr)
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token") accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
s.Error(err) s.Error(err)
s.Empty(accessToken) s.Empty(accessToken)
s.Empty(refreshToken)
} }
func (s *Suite) TestAuthService_Refresh_UserIDZero() { func (s *Suite) TestAuthService_Refresh_UserIDZero() {
session := createTestSession(0) session := createTestSession(0)
s.sessionRepo.FindByRefreshTokenMock.Return(session, nil) s.sessionRepo.FindByRefreshTokenMock.Return(session, nil)
s.sessionRepo.UpdateAccessTokenMock.Return(nil) s.sessionRepo.UpdateTokensMock.Return(nil)
accessToken, err := s.authService.Refresh(s.ctx, "test-refresh-token") accessToken, refreshToken, err := s.authService.Refresh(s.ctx, "test-refresh-token")
s.NoError(err) s.NoError(err)
s.NotEmpty(accessToken) s.NotEmpty(accessToken)
s.NotEmpty(refreshToken)
} }
func (s *Suite) TestAuthService_Validate_Success() { func (s *Suite) TestAuthService_Validate_Success() {

View File

@@ -6,11 +6,12 @@ import (
"crypto/hmac" "crypto/hmac"
"crypto/rand" "crypto/rand"
"crypto/sha256" "crypto/sha256"
"crypto/sha512"
"encoding/hex" "encoding/hex"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
"golang.org/x/crypto/bcrypt"
) )
type Crypto struct { type Crypto struct {
@@ -31,10 +32,19 @@ func (c *Crypto) EmailHash(email string) string {
return hex.EncodeToString(h.Sum(nil)) return hex.EncodeToString(h.Sum(nil))
} }
const bcryptCost = 12
func PasswordHash(password string) string { func PasswordHash(password string) string {
h := sha512.New() hash, err := bcrypt.GenerateFromPassword([]byte(password), bcryptCost)
h.Write([]byte(password)) if err != nil {
return hex.EncodeToString(h.Sum(nil)) return ""
}
return string(hash)
}
func PasswordVerify(password, hash string) bool {
err := bcrypt.CompareHashAndPassword([]byte(hash), []byte(password))
return err == nil
} }
func (c *Crypto) getKey() []byte { func (c *Crypto) getKey() []byte {

View File

@@ -15,6 +15,13 @@ const (
UnsupportedFileFormat = "UNSUPPORTED_FILE_FORMAT" UnsupportedFileFormat = "UNSUPPORTED_FILE_FORMAT"
FileProcessingError = "FILE_PROCESSING_ERROR" FileProcessingError = "FILE_PROCESSING_ERROR"
ValidationInvalidEmail = "VALIDATION_INVALID_EMAIL"
ValidationInvalidPassword = "VALIDATION_INVALID_PASSWORD"
ValidationInvalidPhone = "VALIDATION_INVALID_PHONE"
ValidationInvalidName = "VALIDATION_INVALID_NAME"
ValidationFileTooLarge = "VALIDATION_FILE_TOO_LARGE"
ValidationRequestTooLong = "VALIDATION_REQUEST_TOO_LONG"
DatabaseError = "DATABASE_ERROR" DatabaseError = "DATABASE_ERROR"
EncryptionError = "ENCRYPTION_ERROR" EncryptionError = "ENCRYPTION_ERROR"
AIAPIError = "AI_API_ERROR" AIAPIError = "AI_API_ERROR"

View File

@@ -0,0 +1,156 @@
package validation
import (
"net/mail"
"regexp"
"strings"
"unicode"
"git.techease.ru/Smart-search/smart-search-back/pkg/errors"
)
const (
MinPasswordLength = 8
MaxPasswordLength = 128
MaxEmailLength = 254
MaxNameLength = 100
MaxPhoneLength = 20
MaxRequestTxtLen = 50000
MaxFileSizeBytes = 10 * 1024 * 1024
)
var (
phoneRegex = regexp.MustCompile(`^\+?[1-9]\d{6,14}$`)
)
func ValidateEmail(email string) error {
if email == "" {
return errors.NewBusinessError(errors.ValidationInvalidEmail, "email is required")
}
if len(email) > MaxEmailLength {
return errors.NewBusinessError(errors.ValidationInvalidEmail, "email is too long")
}
_, err := mail.ParseAddress(email)
if err != nil {
return errors.NewBusinessError(errors.ValidationInvalidEmail, "invalid email format")
}
parts := strings.Split(email, "@")
if len(parts) != 2 || parts[0] == "" || parts[1] == "" {
return errors.NewBusinessError(errors.ValidationInvalidEmail, "invalid email format")
}
if strings.Contains(parts[1], "..") {
return errors.NewBusinessError(errors.ValidationInvalidEmail, "invalid email format")
}
return nil
}
func ValidatePassword(password string) error {
if password == "" {
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password is required")
}
if len(password) < MinPasswordLength {
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password must be at least 8 characters")
}
if len(password) > MaxPasswordLength {
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password is too long")
}
var hasUpper, hasLower, hasDigit bool
for _, c := range password {
switch {
case unicode.IsUpper(c):
hasUpper = true
case unicode.IsLower(c):
hasLower = true
case unicode.IsDigit(c):
hasDigit = true
}
}
if !hasUpper || !hasLower || !hasDigit {
return errors.NewBusinessError(errors.ValidationInvalidPassword, "password must contain uppercase, lowercase and digit")
}
return nil
}
func ValidatePhone(phone string) error {
if phone == "" {
return errors.NewBusinessError(errors.ValidationInvalidPhone, "phone is required")
}
if len(phone) > MaxPhoneLength {
return errors.NewBusinessError(errors.ValidationInvalidPhone, "phone is too long")
}
cleaned := strings.ReplaceAll(phone, " ", "")
cleaned = strings.ReplaceAll(cleaned, "-", "")
cleaned = strings.ReplaceAll(cleaned, "(", "")
cleaned = strings.ReplaceAll(cleaned, ")", "")
if !phoneRegex.MatchString(cleaned) {
return errors.NewBusinessError(errors.ValidationInvalidPhone, "invalid phone format")
}
return nil
}
func ValidateName(name string) error {
if name == "" {
return errors.NewBusinessError(errors.ValidationInvalidName, "name is required")
}
if len(name) > MaxNameLength {
return errors.NewBusinessError(errors.ValidationInvalidName, "name is too long")
}
trimmed := strings.TrimSpace(name)
if trimmed == "" {
return errors.NewBusinessError(errors.ValidationInvalidName, "name cannot be only whitespace")
}
return nil
}
func ValidateRequestTxt(txt string) error {
if len(txt) > MaxRequestTxtLen {
return errors.NewBusinessError(errors.ValidationRequestTooLong, "request text exceeds 50000 characters limit")
}
return nil
}
func ValidateFileSize(size int) error {
if size > MaxFileSizeBytes {
return errors.NewBusinessError(errors.ValidationFileTooLarge, "file size exceeds 10MB limit")
}
return nil
}
func ValidateRegistration(email, password, name, phone string) error {
if err := ValidateEmail(email); err != nil {
return err
}
if err := ValidatePassword(password); err != nil {
return err
}
if err := ValidateName(name); err != nil {
return err
}
if err := ValidatePhone(phone); err != nil {
return err
}
return nil
}

View File

@@ -0,0 +1,222 @@
package validation
import (
"strings"
"testing"
"git.techease.ru/Smart-search/smart-search-back/pkg/errors"
"github.com/stretchr/testify/assert"
)
func TestValidateEmail(t *testing.T) {
tests := []struct {
name string
email string
wantErr bool
wantCode string
}{
{"valid email", "test@example.com", false, ""},
{"valid with subdomain", "test@sub.example.com", false, ""},
{"valid with plus", "test+tag@example.com", false, ""},
{"valid with dots", "test.name@example.com", false, ""},
{"empty", "", true, errors.ValidationInvalidEmail},
{"no at sign", "testexample.com", true, errors.ValidationInvalidEmail},
{"no domain", "test@", true, errors.ValidationInvalidEmail},
{"no local part", "@example.com", true, errors.ValidationInvalidEmail},
{"double dots in domain", "test@example..com", true, errors.ValidationInvalidEmail},
{"too long", strings.Repeat("a", 255) + "@example.com", true, errors.ValidationInvalidEmail},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateEmail(tt.email)
if tt.wantErr {
assert.Error(t, err)
if appErr, ok := err.(*errors.AppError); ok {
assert.Equal(t, tt.wantCode, appErr.Code)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidatePassword(t *testing.T) {
tests := []struct {
name string
password string
wantErr bool
wantCode string
}{
{"valid password", "Abcd1234", false, ""},
{"valid with special chars", "Abcd1234!", false, ""},
{"empty", "", true, errors.ValidationInvalidPassword},
{"too short", "Ab1", true, errors.ValidationInvalidPassword},
{"no uppercase", "abcd1234", true, errors.ValidationInvalidPassword},
{"no lowercase", "ABCD1234", true, errors.ValidationInvalidPassword},
{"no digit", "Abcdefgh", true, errors.ValidationInvalidPassword},
{"only digits", "12345678", true, errors.ValidationInvalidPassword},
{"only lowercase", "abcdefgh", true, errors.ValidationInvalidPassword},
{"only uppercase", "ABCDEFGH", true, errors.ValidationInvalidPassword},
{"too long", strings.Repeat("Aa1", 50), true, errors.ValidationInvalidPassword},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidatePassword(tt.password)
if tt.wantErr {
assert.Error(t, err)
if appErr, ok := err.(*errors.AppError); ok {
assert.Equal(t, tt.wantCode, appErr.Code)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidatePhone(t *testing.T) {
tests := []struct {
name string
phone string
wantErr bool
wantCode string
}{
{"valid international", "+1234567890", false, ""},
{"valid with country code", "+79123456789", false, ""},
{"valid without plus", "1234567890", false, ""},
{"empty", "", true, errors.ValidationInvalidPhone},
{"too short", "123", true, errors.ValidationInvalidPhone},
{"letters", "abcdefgh", true, errors.ValidationInvalidPhone},
{"too long", "+123456789012345678901", true, errors.ValidationInvalidPhone},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidatePhone(tt.phone)
if tt.wantErr {
assert.Error(t, err)
if appErr, ok := err.(*errors.AppError); ok {
assert.Equal(t, tt.wantCode, appErr.Code)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateName(t *testing.T) {
tests := []struct {
name string
value string
wantErr bool
wantCode string
}{
{"valid name", "John Doe", false, ""},
{"valid cyrillic", "Иван Иванов", false, ""},
{"empty", "", true, errors.ValidationInvalidName},
{"only whitespace", " ", true, errors.ValidationInvalidName},
{"too long", strings.Repeat("a", 101), true, errors.ValidationInvalidName},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateName(tt.value)
if tt.wantErr {
assert.Error(t, err)
if appErr, ok := err.(*errors.AppError); ok {
assert.Equal(t, tt.wantCode, appErr.Code)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateRequestTxt(t *testing.T) {
tests := []struct {
name string
txt string
wantErr bool
wantCode string
}{
{"valid short", "Test request", false, ""},
{"empty is valid", "", false, ""},
{"max length", strings.Repeat("a", MaxRequestTxtLen), false, ""},
{"too long", strings.Repeat("a", MaxRequestTxtLen+1), true, errors.ValidationRequestTooLong},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateRequestTxt(tt.txt)
if tt.wantErr {
assert.Error(t, err)
if appErr, ok := err.(*errors.AppError); ok {
assert.Equal(t, tt.wantCode, appErr.Code)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateFileSize(t *testing.T) {
tests := []struct {
name string
size int
wantErr bool
wantCode string
}{
{"zero", 0, false, ""},
{"small file", 1024, false, ""},
{"max size", MaxFileSizeBytes, false, ""},
{"too large", MaxFileSizeBytes + 1, true, errors.ValidationFileTooLarge},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateFileSize(tt.size)
if tt.wantErr {
assert.Error(t, err)
if appErr, ok := err.(*errors.AppError); ok {
assert.Equal(t, tt.wantCode, appErr.Code)
}
} else {
assert.NoError(t, err)
}
})
}
}
func TestValidateRegistration(t *testing.T) {
tests := []struct {
name string
email string
password string
userName string
phone string
wantErr bool
}{
{"valid", "test@example.com", "Abcd1234", "John Doe", "+1234567890", false},
{"invalid email", "invalid", "Abcd1234", "John Doe", "+1234567890", true},
{"invalid password", "test@example.com", "weak", "John Doe", "+1234567890", true},
{"invalid name", "test@example.com", "Abcd1234", "", "+1234567890", true},
{"invalid phone", "test@example.com", "Abcd1234", "John Doe", "abc", true},
}
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
err := ValidateRegistration(tt.email, tt.password, tt.userName, tt.phone)
if tt.wantErr {
assert.Error(t, err)
} else {
assert.NoError(t, err)
}
})
}
}

View File

@@ -154,6 +154,7 @@ func (s *IntegrationSuite) createTestUser(email, password string) {
emailHash := cryptoHelper.EmailHash(email) emailHash := cryptoHelper.EmailHash(email)
passwordHash := crypto.PasswordHash(password) passwordHash := crypto.PasswordHash(password)
s.Require().NotEmpty(passwordHash, "password hash should not be empty")
query := ` query := `
INSERT INTO users (email, email_hash, password_hash, phone, user_name, company_name, balance, payment_status, invites_issued, invites_limit) INSERT INTO users (email, email_hash, password_hash, phone, user_name, company_name, balance, payment_status, invites_issued, invites_limit)

1206
tests/security_test.go Normal file

File diff suppressed because it is too large Load Diff