diff --git a/DEPLOYMENT.md b/DEPLOYMENT.md index 0fe96c2..06615e7 100644 --- a/DEPLOYMENT.md +++ b/DEPLOYMENT.md @@ -128,7 +128,7 @@ make migrate-create name=add_new_field Сервер запускается на порту 9091 и предоставляет следующие сервисы: -- **AuthService** - аутентификация +- **AuthService** - аутентификация и регистрация (Register, Login, Refresh, Validate, Logout) - **UserService** - управление пользователями - **InviteService** - управление инвайт-кодами - **RequestService** - управление запросами diff --git a/GRPC_SERVICES.md b/GRPC_SERVICES.md index 43d9ed0..724056c 100644 --- a/GRPC_SERVICES.md +++ b/GRPC_SERVICES.md @@ -25,6 +25,7 @@ internal/grpc/ | Метод | Описание | |-------|----------| +| `Register` | Регистрация нового пользователя по инвайт-коду | | `Login` | Аутентификация пользователя (email + password) | | `Refresh` | Обновление access token по refresh token | | `Validate` | Валидация access token | @@ -153,6 +154,8 @@ resp, err := supplierClient.ExportExcel(ctx, req) | `USER_NOT_FOUND` | `NotFound` | | `REQUEST_NOT_FOUND` | `NotFound` | | `INVITE_LIMIT_REACHED` | `ResourceExhausted` | +| `INVITE_INVALID_OR_EXPIRED` | `FailedPrecondition` | +| `EMAIL_ALREADY_EXISTS` | `AlreadyExists` | | `INSUFFICIENT_BALANCE` | `FailedPrecondition` | | Внутренние ошибки | `Internal` (без деталей) | @@ -238,7 +241,7 @@ grpc: ## Статистика -- **Всего gRPC методов**: 16 +- **Всего gRPC методов**: 17 - **Всего handlers**: 5 -- **Строк кода handlers**: ~371 +- **Строк кода handlers**: ~390 - **Proto файлов**: 5 diff --git a/Makefile b/Makefile index 8efb2cb..8c4ece1 100644 --- a/Makefile +++ b/Makefile @@ -85,7 +85,7 @@ test: test-integration: @echo "Running integration tests with testcontainers..." @echo "This may take several minutes..." - go test -v -timeout=10m ./internal/grpc/tests/... + go test -v -timeout=10m ./tests/... # Default DB URL for local development DB_URL ?= postgres://postgres:password@localhost:5432/b2b_search?sslmode=disable diff --git a/README.md b/README.md index 33d422a..cec17af 100644 --- a/README.md +++ b/README.md @@ -37,8 +37,8 @@ Backend микросервис для системы поиска поставщ ### gRPC Services -**5 gRPC сервисов с 16 методами**: -- `AuthService` - аутентификация (Login, Refresh, Validate, Logout) +**5 gRPC сервисов с 17 методами**: +- `AuthService` - аутентификация (Register, Login, Refresh, Validate, Logout) - `UserService` - информация о пользователе и статистика - `InviteService` - управление инвайт-кодами - `RequestService` - создание и управление заявками с AI diff --git a/TESTING.md b/TESTING.md index 05a0234..41ed28b 100644 --- a/TESTING.md +++ b/TESTING.md @@ -100,6 +100,42 @@ func TestAuthService_Login_Success(t *testing.T) { // Minimock автоматически проверит что все ожидания выполнены } + +func TestAuthService_Register_Success(t *testing.T) { + mockUserRepo := mocks.NewUserRepositoryMock(t) + mockSessionRepo := mocks.NewSessionRepositoryMock(t) + mockInviteRepo := mocks.NewInviteRepositoryMock(t) + + mockInviteRepo.FindActiveByCodeMock.Expect(context.Background(), int64(123456)).Return(&model.InviteCode{ + Code: 123456, + CanBeUsedCount: 5, + UsedCount: 0, + }, nil) + + mockUserRepo.FindByEmailHashMock.Expect(context.Background(), "email_hash").Return(nil, + errors.NewBusinessError(errors.UserNotFound, "user not found")) + + mockUserRepo.CreateTxMock.Return(nil) + mockInviteRepo.DecrementCanBeUsedCountTxMock.Return(nil) + mockSessionRepo.CreateMock.Return(nil) + + authService := service.NewAuthService(mockUserRepo, mockSessionRepo, mockInviteRepo, txManager, "secret", "cryptosecret") + + accessToken, refreshToken, err := authService.Register( + context.Background(), + "newuser@example.com", + "password123", + "New User", + "+1234567890", + 123456, + "127.0.0.1", + "test-agent", + ) + + assert.NoError(t, err) + assert.NotEmpty(t, accessToken) + assert.NotEmpty(t, refreshToken) +} ``` ### Преимущества minimock diff --git a/api/proto/auth/auth.proto b/api/proto/auth/auth.proto index 9b08765..2a3b7f3 100644 --- a/api/proto/auth/auth.proto +++ b/api/proto/auth/auth.proto @@ -3,6 +3,7 @@ package auth; option go_package = "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth"; service AuthService { + rpc Register(RegisterRequest) returns (RegisterResponse); rpc Login(LoginRequest) returns (LoginResponse); rpc Refresh(RefreshRequest) returns (RefreshResponse); rpc Validate(ValidateRequest) returns (ValidateResponse); @@ -48,3 +49,18 @@ message LogoutRequest { message LogoutResponse { bool success = 1; } + +message RegisterRequest { + string email = 1; + string password = 2; + string name = 3; + string phone = 4; + int64 invite_code = 5; + string ip = 6; + string user_agent = 7; +} + +message RegisterResponse { + string access_token = 1; + string refresh_token = 2; +} diff --git a/internal/grpc/auth_handler.go b/internal/grpc/auth_handler.go index 837c48f..5718fa8 100644 --- a/internal/grpc/auth_handler.go +++ b/internal/grpc/auth_handler.go @@ -9,6 +9,27 @@ import ( "go.uber.org/zap" ) +func (h *AuthHandler) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) { + accessToken, refreshToken, err := h.authService.Register( + ctx, + req.Email, + req.Password, + req.Name, + req.Phone, + req.InviteCode, + req.Ip, + req.UserAgent, + ) + if err != nil { + return nil, errors.ToGRPCError(err, h.logger, "AuthService.Register") + } + + return &pb.RegisterResponse{ + AccessToken: accessToken, + RefreshToken: refreshToken, + }, nil +} + func (h *AuthHandler) Login(ctx context.Context, req *pb.LoginRequest) (*pb.LoginResponse, error) { accessToken, refreshToken, err := h.authService.Login( ctx, diff --git a/internal/grpc/server.go b/internal/grpc/server.go index 2b61c3a..5333000 100644 --- a/internal/grpc/server.go +++ b/internal/grpc/server.go @@ -58,7 +58,7 @@ func NewHandlers(pool *pgxpool.Pool, jwtSecret, cryptoSecret, openAIKey, perplex openAIClient := ai.NewOpenAIClient(openAIKey) perplexityClient := ai.NewPerplexityClient(perplexityKey) - authService := service.NewAuthService(userRepo, sessionRepo, jwtSecret, cryptoSecret) + authService := service.NewAuthService(userRepo, sessionRepo, inviteRepo, txManager, jwtSecret, cryptoSecret) userService := service.NewUserService(userRepo, requestRepo, cryptoSecret) inviteService := service.NewInviteService(inviteRepo, userRepo, txManager) requestService := service.NewRequestService(requestRepo, supplierRepo, tokenUsageRepo, userRepo, openAIClient, perplexityClient, txManager) diff --git a/internal/mocks/auth_service_mock.go b/internal/mocks/auth_service_mock.go index af3898d..0085737 100644 --- a/internal/mocks/auth_service_mock.go +++ b/internal/mocks/auth_service_mock.go @@ -39,6 +39,13 @@ type AuthServiceMock struct { beforeRefreshCounter uint64 RefreshMock mAuthServiceMockRefresh + funcRegister func(ctx context.Context, email string, password string, name string, phone string, inviteCode int64, ip string, userAgent string) (accessToken string, refreshToken string, err error) + funcRegisterOrigin string + inspectFuncRegister func(ctx context.Context, email string, password string, name string, phone string, inviteCode int64, ip string, userAgent string) + afterRegisterCounter uint64 + beforeRegisterCounter uint64 + RegisterMock mAuthServiceMockRegister + funcValidate func(ctx context.Context, accessToken string) (i1 int, err error) funcValidateOrigin string inspectFuncValidate func(ctx context.Context, accessToken string) @@ -64,6 +71,9 @@ func NewAuthServiceMock(t minimock.Tester) *AuthServiceMock { m.RefreshMock = mAuthServiceMockRefresh{mock: m} m.RefreshMock.callArgs = []*AuthServiceMockRefreshParams{} + m.RegisterMock = mAuthServiceMockRegister{mock: m} + m.RegisterMock.callArgs = []*AuthServiceMockRegisterParams{} + m.ValidateMock = mAuthServiceMockValidate{mock: m} m.ValidateMock.callArgs = []*AuthServiceMockValidateParams{} @@ -1194,6 +1204,536 @@ func (m *AuthServiceMock) MinimockRefreshInspect() { } } +type mAuthServiceMockRegister struct { + optional bool + mock *AuthServiceMock + defaultExpectation *AuthServiceMockRegisterExpectation + expectations []*AuthServiceMockRegisterExpectation + + callArgs []*AuthServiceMockRegisterParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// AuthServiceMockRegisterExpectation specifies expectation struct of the AuthService.Register +type AuthServiceMockRegisterExpectation struct { + mock *AuthServiceMock + params *AuthServiceMockRegisterParams + paramPtrs *AuthServiceMockRegisterParamPtrs + expectationOrigins AuthServiceMockRegisterExpectationOrigins + results *AuthServiceMockRegisterResults + returnOrigin string + Counter uint64 +} + +// AuthServiceMockRegisterParams contains parameters of the AuthService.Register +type AuthServiceMockRegisterParams struct { + ctx context.Context + email string + password string + name string + phone string + inviteCode int64 + ip string + userAgent string +} + +// AuthServiceMockRegisterParamPtrs contains pointers to parameters of the AuthService.Register +type AuthServiceMockRegisterParamPtrs struct { + ctx *context.Context + email *string + password *string + name *string + phone *string + inviteCode *int64 + ip *string + userAgent *string +} + +// AuthServiceMockRegisterResults contains results of the AuthService.Register +type AuthServiceMockRegisterResults struct { + accessToken string + refreshToken string + err error +} + +// AuthServiceMockRegisterOrigins contains origins of expectations of the AuthService.Register +type AuthServiceMockRegisterExpectationOrigins struct { + origin string + originCtx string + originEmail string + originPassword string + originName string + originPhone string + originInviteCode string + originIp string + originUserAgent string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmRegister *mAuthServiceMockRegister) Optional() *mAuthServiceMockRegister { + mmRegister.optional = true + return mmRegister +} + +// Expect sets up expected params for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) Expect(ctx context.Context, email string, password string, name string, phone string, inviteCode int64, ip string, userAgent string) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.paramPtrs != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by ExpectParams functions") + } + + mmRegister.defaultExpectation.params = &AuthServiceMockRegisterParams{ctx, email, password, name, phone, inviteCode, ip, userAgent} + mmRegister.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmRegister.expectations { + if minimock.Equal(e.params, mmRegister.defaultExpectation.params) { + mmRegister.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmRegister.defaultExpectation.params) + } + } + + return mmRegister +} + +// ExpectCtxParam1 sets up expected param ctx for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectCtxParam1(ctx context.Context) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.ctx = &ctx + mmRegister.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmRegister +} + +// ExpectEmailParam2 sets up expected param email for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectEmailParam2(email string) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.email = &email + mmRegister.defaultExpectation.expectationOrigins.originEmail = minimock.CallerInfo(1) + + return mmRegister +} + +// ExpectPasswordParam3 sets up expected param password for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectPasswordParam3(password string) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.password = &password + mmRegister.defaultExpectation.expectationOrigins.originPassword = minimock.CallerInfo(1) + + return mmRegister +} + +// ExpectNameParam4 sets up expected param name for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectNameParam4(name string) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.name = &name + mmRegister.defaultExpectation.expectationOrigins.originName = minimock.CallerInfo(1) + + return mmRegister +} + +// ExpectPhoneParam5 sets up expected param phone for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectPhoneParam5(phone string) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.phone = &phone + mmRegister.defaultExpectation.expectationOrigins.originPhone = minimock.CallerInfo(1) + + return mmRegister +} + +// ExpectInviteCodeParam6 sets up expected param inviteCode for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectInviteCodeParam6(inviteCode int64) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.inviteCode = &inviteCode + mmRegister.defaultExpectation.expectationOrigins.originInviteCode = minimock.CallerInfo(1) + + return mmRegister +} + +// ExpectIpParam7 sets up expected param ip for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectIpParam7(ip string) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.ip = &ip + mmRegister.defaultExpectation.expectationOrigins.originIp = minimock.CallerInfo(1) + + return mmRegister +} + +// ExpectUserAgentParam8 sets up expected param userAgent for AuthService.Register +func (mmRegister *mAuthServiceMockRegister) ExpectUserAgentParam8(userAgent string) *mAuthServiceMockRegister { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{} + } + + if mmRegister.defaultExpectation.params != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Expect") + } + + if mmRegister.defaultExpectation.paramPtrs == nil { + mmRegister.defaultExpectation.paramPtrs = &AuthServiceMockRegisterParamPtrs{} + } + mmRegister.defaultExpectation.paramPtrs.userAgent = &userAgent + mmRegister.defaultExpectation.expectationOrigins.originUserAgent = minimock.CallerInfo(1) + + return mmRegister +} + +// Inspect accepts an inspector function that has same arguments as the AuthService.Register +func (mmRegister *mAuthServiceMockRegister) Inspect(f func(ctx context.Context, email string, password string, name string, phone string, inviteCode int64, ip string, userAgent string)) *mAuthServiceMockRegister { + if mmRegister.mock.inspectFuncRegister != nil { + mmRegister.mock.t.Fatalf("Inspect function is already set for AuthServiceMock.Register") + } + + mmRegister.mock.inspectFuncRegister = f + + return mmRegister +} + +// Return sets up results that will be returned by AuthService.Register +func (mmRegister *mAuthServiceMockRegister) Return(accessToken string, refreshToken string, err error) *AuthServiceMock { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + if mmRegister.defaultExpectation == nil { + mmRegister.defaultExpectation = &AuthServiceMockRegisterExpectation{mock: mmRegister.mock} + } + mmRegister.defaultExpectation.results = &AuthServiceMockRegisterResults{accessToken, refreshToken, err} + mmRegister.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmRegister.mock +} + +// Set uses given function f to mock the AuthService.Register method +func (mmRegister *mAuthServiceMockRegister) Set(f func(ctx context.Context, email string, password string, name string, phone string, inviteCode int64, ip string, userAgent string) (accessToken string, refreshToken string, err error)) *AuthServiceMock { + if mmRegister.defaultExpectation != nil { + mmRegister.mock.t.Fatalf("Default expectation is already set for the AuthService.Register method") + } + + if len(mmRegister.expectations) > 0 { + mmRegister.mock.t.Fatalf("Some expectations are already set for the AuthService.Register method") + } + + mmRegister.mock.funcRegister = f + mmRegister.mock.funcRegisterOrigin = minimock.CallerInfo(1) + return mmRegister.mock +} + +// When sets expectation for the AuthService.Register which will trigger the result defined by the following +// Then helper +func (mmRegister *mAuthServiceMockRegister) When(ctx context.Context, email string, password string, name string, phone string, inviteCode int64, ip string, userAgent string) *AuthServiceMockRegisterExpectation { + if mmRegister.mock.funcRegister != nil { + mmRegister.mock.t.Fatalf("AuthServiceMock.Register mock is already set by Set") + } + + expectation := &AuthServiceMockRegisterExpectation{ + mock: mmRegister.mock, + params: &AuthServiceMockRegisterParams{ctx, email, password, name, phone, inviteCode, ip, userAgent}, + expectationOrigins: AuthServiceMockRegisterExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmRegister.expectations = append(mmRegister.expectations, expectation) + return expectation +} + +// Then sets up AuthService.Register return parameters for the expectation previously defined by the When method +func (e *AuthServiceMockRegisterExpectation) Then(accessToken string, refreshToken string, err error) *AuthServiceMock { + e.results = &AuthServiceMockRegisterResults{accessToken, refreshToken, err} + return e.mock +} + +// Times sets number of times AuthService.Register should be invoked +func (mmRegister *mAuthServiceMockRegister) Times(n uint64) *mAuthServiceMockRegister { + if n == 0 { + mmRegister.mock.t.Fatalf("Times of AuthServiceMock.Register mock can not be zero") + } + mm_atomic.StoreUint64(&mmRegister.expectedInvocations, n) + mmRegister.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmRegister +} + +func (mmRegister *mAuthServiceMockRegister) invocationsDone() bool { + if len(mmRegister.expectations) == 0 && mmRegister.defaultExpectation == nil && mmRegister.mock.funcRegister == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmRegister.mock.afterRegisterCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmRegister.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// Register implements mm_service.AuthService +func (mmRegister *AuthServiceMock) Register(ctx context.Context, email string, password string, name string, phone string, inviteCode int64, ip string, userAgent string) (accessToken string, refreshToken string, err error) { + mm_atomic.AddUint64(&mmRegister.beforeRegisterCounter, 1) + defer mm_atomic.AddUint64(&mmRegister.afterRegisterCounter, 1) + + mmRegister.t.Helper() + + if mmRegister.inspectFuncRegister != nil { + mmRegister.inspectFuncRegister(ctx, email, password, name, phone, inviteCode, ip, userAgent) + } + + mm_params := AuthServiceMockRegisterParams{ctx, email, password, name, phone, inviteCode, ip, userAgent} + + // Record call args + mmRegister.RegisterMock.mutex.Lock() + mmRegister.RegisterMock.callArgs = append(mmRegister.RegisterMock.callArgs, &mm_params) + mmRegister.RegisterMock.mutex.Unlock() + + for _, e := range mmRegister.RegisterMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.accessToken, e.results.refreshToken, e.results.err + } + } + + if mmRegister.RegisterMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmRegister.RegisterMock.defaultExpectation.Counter, 1) + mm_want := mmRegister.RegisterMock.defaultExpectation.params + mm_want_ptrs := mmRegister.RegisterMock.defaultExpectation.paramPtrs + + mm_got := AuthServiceMockRegisterParams{ctx, email, password, name, phone, inviteCode, ip, userAgent} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.email != nil && !minimock.Equal(*mm_want_ptrs.email, mm_got.email) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter email, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originEmail, *mm_want_ptrs.email, mm_got.email, minimock.Diff(*mm_want_ptrs.email, mm_got.email)) + } + + if mm_want_ptrs.password != nil && !minimock.Equal(*mm_want_ptrs.password, mm_got.password) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter password, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originPassword, *mm_want_ptrs.password, mm_got.password, minimock.Diff(*mm_want_ptrs.password, mm_got.password)) + } + + if mm_want_ptrs.name != nil && !minimock.Equal(*mm_want_ptrs.name, mm_got.name) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter name, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originName, *mm_want_ptrs.name, mm_got.name, minimock.Diff(*mm_want_ptrs.name, mm_got.name)) + } + + if mm_want_ptrs.phone != nil && !minimock.Equal(*mm_want_ptrs.phone, mm_got.phone) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter phone, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originPhone, *mm_want_ptrs.phone, mm_got.phone, minimock.Diff(*mm_want_ptrs.phone, mm_got.phone)) + } + + if mm_want_ptrs.inviteCode != nil && !minimock.Equal(*mm_want_ptrs.inviteCode, mm_got.inviteCode) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter inviteCode, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originInviteCode, *mm_want_ptrs.inviteCode, mm_got.inviteCode, minimock.Diff(*mm_want_ptrs.inviteCode, mm_got.inviteCode)) + } + + if mm_want_ptrs.ip != nil && !minimock.Equal(*mm_want_ptrs.ip, mm_got.ip) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter ip, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originIp, *mm_want_ptrs.ip, mm_got.ip, minimock.Diff(*mm_want_ptrs.ip, mm_got.ip)) + } + + if mm_want_ptrs.userAgent != nil && !minimock.Equal(*mm_want_ptrs.userAgent, mm_got.userAgent) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameter userAgent, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.originUserAgent, *mm_want_ptrs.userAgent, mm_got.userAgent, minimock.Diff(*mm_want_ptrs.userAgent, mm_got.userAgent)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmRegister.t.Errorf("AuthServiceMock.Register got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmRegister.RegisterMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmRegister.RegisterMock.defaultExpectation.results + if mm_results == nil { + mmRegister.t.Fatal("No results are set for the AuthServiceMock.Register") + } + return (*mm_results).accessToken, (*mm_results).refreshToken, (*mm_results).err + } + if mmRegister.funcRegister != nil { + return mmRegister.funcRegister(ctx, email, password, name, phone, inviteCode, ip, userAgent) + } + mmRegister.t.Fatalf("Unexpected call to AuthServiceMock.Register. %v %v %v %v %v %v %v %v", ctx, email, password, name, phone, inviteCode, ip, userAgent) + return +} + +// RegisterAfterCounter returns a count of finished AuthServiceMock.Register invocations +func (mmRegister *AuthServiceMock) RegisterAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmRegister.afterRegisterCounter) +} + +// RegisterBeforeCounter returns a count of AuthServiceMock.Register invocations +func (mmRegister *AuthServiceMock) RegisterBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmRegister.beforeRegisterCounter) +} + +// Calls returns a list of arguments used in each call to AuthServiceMock.Register. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmRegister *mAuthServiceMockRegister) Calls() []*AuthServiceMockRegisterParams { + mmRegister.mutex.RLock() + + argCopy := make([]*AuthServiceMockRegisterParams, len(mmRegister.callArgs)) + copy(argCopy, mmRegister.callArgs) + + mmRegister.mutex.RUnlock() + + return argCopy +} + +// MinimockRegisterDone returns true if the count of the Register invocations corresponds +// the number of defined expectations +func (m *AuthServiceMock) MinimockRegisterDone() bool { + if m.RegisterMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.RegisterMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.RegisterMock.invocationsDone() +} + +// MinimockRegisterInspect logs each unmet expectation +func (m *AuthServiceMock) MinimockRegisterInspect() { + for _, e := range m.RegisterMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to AuthServiceMock.Register at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterRegisterCounter := mm_atomic.LoadUint64(&m.afterRegisterCounter) + // if default expectation was set then invocations count should be greater than zero + if m.RegisterMock.defaultExpectation != nil && afterRegisterCounter < 1 { + if m.RegisterMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to AuthServiceMock.Register at\n%s", m.RegisterMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to AuthServiceMock.Register at\n%s with params: %#v", m.RegisterMock.defaultExpectation.expectationOrigins.origin, *m.RegisterMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcRegister != nil && afterRegisterCounter < 1 { + m.t.Errorf("Expected call to AuthServiceMock.Register at\n%s", m.funcRegisterOrigin) + } + + if !m.RegisterMock.invocationsDone() && afterRegisterCounter > 0 { + m.t.Errorf("Expected %d calls to AuthServiceMock.Register at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.RegisterMock.expectedInvocations), m.RegisterMock.expectedInvocationsOrigin, afterRegisterCounter) + } +} + type mAuthServiceMockValidate struct { optional bool mock *AuthServiceMock @@ -1547,6 +2087,8 @@ func (m *AuthServiceMock) MinimockFinish() { m.MinimockRefreshInspect() + m.MinimockRegisterInspect() + m.MinimockValidateInspect() } }) @@ -1574,5 +2116,6 @@ func (m *AuthServiceMock) minimockDone() bool { m.MinimockLoginDone() && m.MinimockLogoutDone() && m.MinimockRefreshDone() && + m.MinimockRegisterDone() && m.MinimockValidateDone() } diff --git a/internal/mocks/invite_repository_mock.go b/internal/mocks/invite_repository_mock.go index 70cfb51..268d4e5 100644 --- a/internal/mocks/invite_repository_mock.go +++ b/internal/mocks/invite_repository_mock.go @@ -41,6 +41,20 @@ type InviteRepositoryMock struct { beforeDeactivateExpiredCounter uint64 DeactivateExpiredMock mInviteRepositoryMockDeactivateExpired + funcDecrementCanBeUsedCountTx func(ctx context.Context, tx pgx.Tx, code int64) (err error) + funcDecrementCanBeUsedCountTxOrigin string + inspectFuncDecrementCanBeUsedCountTx func(ctx context.Context, tx pgx.Tx, code int64) + afterDecrementCanBeUsedCountTxCounter uint64 + beforeDecrementCanBeUsedCountTxCounter uint64 + DecrementCanBeUsedCountTxMock mInviteRepositoryMockDecrementCanBeUsedCountTx + + funcFindActiveByCode func(ctx context.Context, code int64) (ip1 *model.InviteCode, err error) + funcFindActiveByCodeOrigin string + inspectFuncFindActiveByCode func(ctx context.Context, code int64) + afterFindActiveByCodeCounter uint64 + beforeFindActiveByCodeCounter uint64 + FindActiveByCodeMock mInviteRepositoryMockFindActiveByCode + funcFindByCode func(ctx context.Context, code int64) (ip1 *model.InviteCode, err error) funcFindByCodeOrigin string inspectFuncFindByCode func(ctx context.Context, code int64) @@ -80,6 +94,12 @@ func NewInviteRepositoryMock(t minimock.Tester) *InviteRepositoryMock { m.DeactivateExpiredMock = mInviteRepositoryMockDeactivateExpired{mock: m} m.DeactivateExpiredMock.callArgs = []*InviteRepositoryMockDeactivateExpiredParams{} + m.DecrementCanBeUsedCountTxMock = mInviteRepositoryMockDecrementCanBeUsedCountTx{mock: m} + m.DecrementCanBeUsedCountTxMock.callArgs = []*InviteRepositoryMockDecrementCanBeUsedCountTxParams{} + + m.FindActiveByCodeMock = mInviteRepositoryMockFindActiveByCode{mock: m} + m.FindActiveByCodeMock.callArgs = []*InviteRepositoryMockFindActiveByCodeParams{} + m.FindByCodeMock = mInviteRepositoryMockFindByCode{mock: m} m.FindByCodeMock.callArgs = []*InviteRepositoryMockFindByCodeParams{} @@ -1121,6 +1141,722 @@ func (m *InviteRepositoryMock) MinimockDeactivateExpiredInspect() { } } +type mInviteRepositoryMockDecrementCanBeUsedCountTx struct { + optional bool + mock *InviteRepositoryMock + defaultExpectation *InviteRepositoryMockDecrementCanBeUsedCountTxExpectation + expectations []*InviteRepositoryMockDecrementCanBeUsedCountTxExpectation + + callArgs []*InviteRepositoryMockDecrementCanBeUsedCountTxParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// InviteRepositoryMockDecrementCanBeUsedCountTxExpectation specifies expectation struct of the InviteRepository.DecrementCanBeUsedCountTx +type InviteRepositoryMockDecrementCanBeUsedCountTxExpectation struct { + mock *InviteRepositoryMock + params *InviteRepositoryMockDecrementCanBeUsedCountTxParams + paramPtrs *InviteRepositoryMockDecrementCanBeUsedCountTxParamPtrs + expectationOrigins InviteRepositoryMockDecrementCanBeUsedCountTxExpectationOrigins + results *InviteRepositoryMockDecrementCanBeUsedCountTxResults + returnOrigin string + Counter uint64 +} + +// InviteRepositoryMockDecrementCanBeUsedCountTxParams contains parameters of the InviteRepository.DecrementCanBeUsedCountTx +type InviteRepositoryMockDecrementCanBeUsedCountTxParams struct { + ctx context.Context + tx pgx.Tx + code int64 +} + +// InviteRepositoryMockDecrementCanBeUsedCountTxParamPtrs contains pointers to parameters of the InviteRepository.DecrementCanBeUsedCountTx +type InviteRepositoryMockDecrementCanBeUsedCountTxParamPtrs struct { + ctx *context.Context + tx *pgx.Tx + code *int64 +} + +// InviteRepositoryMockDecrementCanBeUsedCountTxResults contains results of the InviteRepository.DecrementCanBeUsedCountTx +type InviteRepositoryMockDecrementCanBeUsedCountTxResults struct { + err error +} + +// InviteRepositoryMockDecrementCanBeUsedCountTxOrigins contains origins of expectations of the InviteRepository.DecrementCanBeUsedCountTx +type InviteRepositoryMockDecrementCanBeUsedCountTxExpectationOrigins struct { + origin string + originCtx string + originTx string + originCode string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) Optional() *mInviteRepositoryMockDecrementCanBeUsedCountTx { + mmDecrementCanBeUsedCountTx.optional = true + return mmDecrementCanBeUsedCountTx +} + +// Expect sets up expected params for InviteRepository.DecrementCanBeUsedCountTx +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) Expect(ctx context.Context, tx pgx.Tx, code int64) *mInviteRepositoryMockDecrementCanBeUsedCountTx { + if mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Set") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation = &InviteRepositoryMockDecrementCanBeUsedCountTxExpectation{} + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by ExpectParams functions") + } + + mmDecrementCanBeUsedCountTx.defaultExpectation.params = &InviteRepositoryMockDecrementCanBeUsedCountTxParams{ctx, tx, code} + mmDecrementCanBeUsedCountTx.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmDecrementCanBeUsedCountTx.expectations { + if minimock.Equal(e.params, mmDecrementCanBeUsedCountTx.defaultExpectation.params) { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmDecrementCanBeUsedCountTx.defaultExpectation.params) + } + } + + return mmDecrementCanBeUsedCountTx +} + +// ExpectCtxParam1 sets up expected param ctx for InviteRepository.DecrementCanBeUsedCountTx +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) ExpectCtxParam1(ctx context.Context) *mInviteRepositoryMockDecrementCanBeUsedCountTx { + if mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Set") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation = &InviteRepositoryMockDecrementCanBeUsedCountTxExpectation{} + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation.params != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Expect") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs = &InviteRepositoryMockDecrementCanBeUsedCountTxParamPtrs{} + } + mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs.ctx = &ctx + mmDecrementCanBeUsedCountTx.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmDecrementCanBeUsedCountTx +} + +// ExpectTxParam2 sets up expected param tx for InviteRepository.DecrementCanBeUsedCountTx +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) ExpectTxParam2(tx pgx.Tx) *mInviteRepositoryMockDecrementCanBeUsedCountTx { + if mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Set") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation = &InviteRepositoryMockDecrementCanBeUsedCountTxExpectation{} + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation.params != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Expect") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs = &InviteRepositoryMockDecrementCanBeUsedCountTxParamPtrs{} + } + mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs.tx = &tx + mmDecrementCanBeUsedCountTx.defaultExpectation.expectationOrigins.originTx = minimock.CallerInfo(1) + + return mmDecrementCanBeUsedCountTx +} + +// ExpectCodeParam3 sets up expected param code for InviteRepository.DecrementCanBeUsedCountTx +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) ExpectCodeParam3(code int64) *mInviteRepositoryMockDecrementCanBeUsedCountTx { + if mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Set") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation = &InviteRepositoryMockDecrementCanBeUsedCountTxExpectation{} + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation.params != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Expect") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs = &InviteRepositoryMockDecrementCanBeUsedCountTxParamPtrs{} + } + mmDecrementCanBeUsedCountTx.defaultExpectation.paramPtrs.code = &code + mmDecrementCanBeUsedCountTx.defaultExpectation.expectationOrigins.originCode = minimock.CallerInfo(1) + + return mmDecrementCanBeUsedCountTx +} + +// Inspect accepts an inspector function that has same arguments as the InviteRepository.DecrementCanBeUsedCountTx +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) Inspect(f func(ctx context.Context, tx pgx.Tx, code int64)) *mInviteRepositoryMockDecrementCanBeUsedCountTx { + if mmDecrementCanBeUsedCountTx.mock.inspectFuncDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("Inspect function is already set for InviteRepositoryMock.DecrementCanBeUsedCountTx") + } + + mmDecrementCanBeUsedCountTx.mock.inspectFuncDecrementCanBeUsedCountTx = f + + return mmDecrementCanBeUsedCountTx +} + +// Return sets up results that will be returned by InviteRepository.DecrementCanBeUsedCountTx +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) Return(err error) *InviteRepositoryMock { + if mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Set") + } + + if mmDecrementCanBeUsedCountTx.defaultExpectation == nil { + mmDecrementCanBeUsedCountTx.defaultExpectation = &InviteRepositoryMockDecrementCanBeUsedCountTxExpectation{mock: mmDecrementCanBeUsedCountTx.mock} + } + mmDecrementCanBeUsedCountTx.defaultExpectation.results = &InviteRepositoryMockDecrementCanBeUsedCountTxResults{err} + mmDecrementCanBeUsedCountTx.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmDecrementCanBeUsedCountTx.mock +} + +// Set uses given function f to mock the InviteRepository.DecrementCanBeUsedCountTx method +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) Set(f func(ctx context.Context, tx pgx.Tx, code int64) (err error)) *InviteRepositoryMock { + if mmDecrementCanBeUsedCountTx.defaultExpectation != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("Default expectation is already set for the InviteRepository.DecrementCanBeUsedCountTx method") + } + + if len(mmDecrementCanBeUsedCountTx.expectations) > 0 { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("Some expectations are already set for the InviteRepository.DecrementCanBeUsedCountTx method") + } + + mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx = f + mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTxOrigin = minimock.CallerInfo(1) + return mmDecrementCanBeUsedCountTx.mock +} + +// When sets expectation for the InviteRepository.DecrementCanBeUsedCountTx which will trigger the result defined by the following +// Then helper +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) When(ctx context.Context, tx pgx.Tx, code int64) *InviteRepositoryMockDecrementCanBeUsedCountTxExpectation { + if mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("InviteRepositoryMock.DecrementCanBeUsedCountTx mock is already set by Set") + } + + expectation := &InviteRepositoryMockDecrementCanBeUsedCountTxExpectation{ + mock: mmDecrementCanBeUsedCountTx.mock, + params: &InviteRepositoryMockDecrementCanBeUsedCountTxParams{ctx, tx, code}, + expectationOrigins: InviteRepositoryMockDecrementCanBeUsedCountTxExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmDecrementCanBeUsedCountTx.expectations = append(mmDecrementCanBeUsedCountTx.expectations, expectation) + return expectation +} + +// Then sets up InviteRepository.DecrementCanBeUsedCountTx return parameters for the expectation previously defined by the When method +func (e *InviteRepositoryMockDecrementCanBeUsedCountTxExpectation) Then(err error) *InviteRepositoryMock { + e.results = &InviteRepositoryMockDecrementCanBeUsedCountTxResults{err} + return e.mock +} + +// Times sets number of times InviteRepository.DecrementCanBeUsedCountTx should be invoked +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) Times(n uint64) *mInviteRepositoryMockDecrementCanBeUsedCountTx { + if n == 0 { + mmDecrementCanBeUsedCountTx.mock.t.Fatalf("Times of InviteRepositoryMock.DecrementCanBeUsedCountTx mock can not be zero") + } + mm_atomic.StoreUint64(&mmDecrementCanBeUsedCountTx.expectedInvocations, n) + mmDecrementCanBeUsedCountTx.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmDecrementCanBeUsedCountTx +} + +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) invocationsDone() bool { + if len(mmDecrementCanBeUsedCountTx.expectations) == 0 && mmDecrementCanBeUsedCountTx.defaultExpectation == nil && mmDecrementCanBeUsedCountTx.mock.funcDecrementCanBeUsedCountTx == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmDecrementCanBeUsedCountTx.mock.afterDecrementCanBeUsedCountTxCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmDecrementCanBeUsedCountTx.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// DecrementCanBeUsedCountTx implements mm_repository.InviteRepository +func (mmDecrementCanBeUsedCountTx *InviteRepositoryMock) DecrementCanBeUsedCountTx(ctx context.Context, tx pgx.Tx, code int64) (err error) { + mm_atomic.AddUint64(&mmDecrementCanBeUsedCountTx.beforeDecrementCanBeUsedCountTxCounter, 1) + defer mm_atomic.AddUint64(&mmDecrementCanBeUsedCountTx.afterDecrementCanBeUsedCountTxCounter, 1) + + mmDecrementCanBeUsedCountTx.t.Helper() + + if mmDecrementCanBeUsedCountTx.inspectFuncDecrementCanBeUsedCountTx != nil { + mmDecrementCanBeUsedCountTx.inspectFuncDecrementCanBeUsedCountTx(ctx, tx, code) + } + + mm_params := InviteRepositoryMockDecrementCanBeUsedCountTxParams{ctx, tx, code} + + // Record call args + mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.mutex.Lock() + mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.callArgs = append(mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.callArgs, &mm_params) + mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.mutex.Unlock() + + for _, e := range mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.err + } + } + + if mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.Counter, 1) + mm_want := mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.params + mm_want_ptrs := mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.paramPtrs + + mm_got := InviteRepositoryMockDecrementCanBeUsedCountTxParams{ctx, tx, code} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmDecrementCanBeUsedCountTx.t.Errorf("InviteRepositoryMock.DecrementCanBeUsedCountTx got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.tx != nil && !minimock.Equal(*mm_want_ptrs.tx, mm_got.tx) { + mmDecrementCanBeUsedCountTx.t.Errorf("InviteRepositoryMock.DecrementCanBeUsedCountTx got unexpected parameter tx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.expectationOrigins.originTx, *mm_want_ptrs.tx, mm_got.tx, minimock.Diff(*mm_want_ptrs.tx, mm_got.tx)) + } + + if mm_want_ptrs.code != nil && !minimock.Equal(*mm_want_ptrs.code, mm_got.code) { + mmDecrementCanBeUsedCountTx.t.Errorf("InviteRepositoryMock.DecrementCanBeUsedCountTx got unexpected parameter code, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.expectationOrigins.originCode, *mm_want_ptrs.code, mm_got.code, minimock.Diff(*mm_want_ptrs.code, mm_got.code)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmDecrementCanBeUsedCountTx.t.Errorf("InviteRepositoryMock.DecrementCanBeUsedCountTx got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmDecrementCanBeUsedCountTx.DecrementCanBeUsedCountTxMock.defaultExpectation.results + if mm_results == nil { + mmDecrementCanBeUsedCountTx.t.Fatal("No results are set for the InviteRepositoryMock.DecrementCanBeUsedCountTx") + } + return (*mm_results).err + } + if mmDecrementCanBeUsedCountTx.funcDecrementCanBeUsedCountTx != nil { + return mmDecrementCanBeUsedCountTx.funcDecrementCanBeUsedCountTx(ctx, tx, code) + } + mmDecrementCanBeUsedCountTx.t.Fatalf("Unexpected call to InviteRepositoryMock.DecrementCanBeUsedCountTx. %v %v %v", ctx, tx, code) + return +} + +// DecrementCanBeUsedCountTxAfterCounter returns a count of finished InviteRepositoryMock.DecrementCanBeUsedCountTx invocations +func (mmDecrementCanBeUsedCountTx *InviteRepositoryMock) DecrementCanBeUsedCountTxAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmDecrementCanBeUsedCountTx.afterDecrementCanBeUsedCountTxCounter) +} + +// DecrementCanBeUsedCountTxBeforeCounter returns a count of InviteRepositoryMock.DecrementCanBeUsedCountTx invocations +func (mmDecrementCanBeUsedCountTx *InviteRepositoryMock) DecrementCanBeUsedCountTxBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmDecrementCanBeUsedCountTx.beforeDecrementCanBeUsedCountTxCounter) +} + +// Calls returns a list of arguments used in each call to InviteRepositoryMock.DecrementCanBeUsedCountTx. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmDecrementCanBeUsedCountTx *mInviteRepositoryMockDecrementCanBeUsedCountTx) Calls() []*InviteRepositoryMockDecrementCanBeUsedCountTxParams { + mmDecrementCanBeUsedCountTx.mutex.RLock() + + argCopy := make([]*InviteRepositoryMockDecrementCanBeUsedCountTxParams, len(mmDecrementCanBeUsedCountTx.callArgs)) + copy(argCopy, mmDecrementCanBeUsedCountTx.callArgs) + + mmDecrementCanBeUsedCountTx.mutex.RUnlock() + + return argCopy +} + +// MinimockDecrementCanBeUsedCountTxDone returns true if the count of the DecrementCanBeUsedCountTx invocations corresponds +// the number of defined expectations +func (m *InviteRepositoryMock) MinimockDecrementCanBeUsedCountTxDone() bool { + if m.DecrementCanBeUsedCountTxMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.DecrementCanBeUsedCountTxMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.DecrementCanBeUsedCountTxMock.invocationsDone() +} + +// MinimockDecrementCanBeUsedCountTxInspect logs each unmet expectation +func (m *InviteRepositoryMock) MinimockDecrementCanBeUsedCountTxInspect() { + for _, e := range m.DecrementCanBeUsedCountTxMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to InviteRepositoryMock.DecrementCanBeUsedCountTx at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterDecrementCanBeUsedCountTxCounter := mm_atomic.LoadUint64(&m.afterDecrementCanBeUsedCountTxCounter) + // if default expectation was set then invocations count should be greater than zero + if m.DecrementCanBeUsedCountTxMock.defaultExpectation != nil && afterDecrementCanBeUsedCountTxCounter < 1 { + if m.DecrementCanBeUsedCountTxMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to InviteRepositoryMock.DecrementCanBeUsedCountTx at\n%s", m.DecrementCanBeUsedCountTxMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to InviteRepositoryMock.DecrementCanBeUsedCountTx at\n%s with params: %#v", m.DecrementCanBeUsedCountTxMock.defaultExpectation.expectationOrigins.origin, *m.DecrementCanBeUsedCountTxMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcDecrementCanBeUsedCountTx != nil && afterDecrementCanBeUsedCountTxCounter < 1 { + m.t.Errorf("Expected call to InviteRepositoryMock.DecrementCanBeUsedCountTx at\n%s", m.funcDecrementCanBeUsedCountTxOrigin) + } + + if !m.DecrementCanBeUsedCountTxMock.invocationsDone() && afterDecrementCanBeUsedCountTxCounter > 0 { + m.t.Errorf("Expected %d calls to InviteRepositoryMock.DecrementCanBeUsedCountTx at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.DecrementCanBeUsedCountTxMock.expectedInvocations), m.DecrementCanBeUsedCountTxMock.expectedInvocationsOrigin, afterDecrementCanBeUsedCountTxCounter) + } +} + +type mInviteRepositoryMockFindActiveByCode struct { + optional bool + mock *InviteRepositoryMock + defaultExpectation *InviteRepositoryMockFindActiveByCodeExpectation + expectations []*InviteRepositoryMockFindActiveByCodeExpectation + + callArgs []*InviteRepositoryMockFindActiveByCodeParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// InviteRepositoryMockFindActiveByCodeExpectation specifies expectation struct of the InviteRepository.FindActiveByCode +type InviteRepositoryMockFindActiveByCodeExpectation struct { + mock *InviteRepositoryMock + params *InviteRepositoryMockFindActiveByCodeParams + paramPtrs *InviteRepositoryMockFindActiveByCodeParamPtrs + expectationOrigins InviteRepositoryMockFindActiveByCodeExpectationOrigins + results *InviteRepositoryMockFindActiveByCodeResults + returnOrigin string + Counter uint64 +} + +// InviteRepositoryMockFindActiveByCodeParams contains parameters of the InviteRepository.FindActiveByCode +type InviteRepositoryMockFindActiveByCodeParams struct { + ctx context.Context + code int64 +} + +// InviteRepositoryMockFindActiveByCodeParamPtrs contains pointers to parameters of the InviteRepository.FindActiveByCode +type InviteRepositoryMockFindActiveByCodeParamPtrs struct { + ctx *context.Context + code *int64 +} + +// InviteRepositoryMockFindActiveByCodeResults contains results of the InviteRepository.FindActiveByCode +type InviteRepositoryMockFindActiveByCodeResults struct { + ip1 *model.InviteCode + err error +} + +// InviteRepositoryMockFindActiveByCodeOrigins contains origins of expectations of the InviteRepository.FindActiveByCode +type InviteRepositoryMockFindActiveByCodeExpectationOrigins struct { + origin string + originCtx string + originCode string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) Optional() *mInviteRepositoryMockFindActiveByCode { + mmFindActiveByCode.optional = true + return mmFindActiveByCode +} + +// Expect sets up expected params for InviteRepository.FindActiveByCode +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) Expect(ctx context.Context, code int64) *mInviteRepositoryMockFindActiveByCode { + if mmFindActiveByCode.mock.funcFindActiveByCode != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by Set") + } + + if mmFindActiveByCode.defaultExpectation == nil { + mmFindActiveByCode.defaultExpectation = &InviteRepositoryMockFindActiveByCodeExpectation{} + } + + if mmFindActiveByCode.defaultExpectation.paramPtrs != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by ExpectParams functions") + } + + mmFindActiveByCode.defaultExpectation.params = &InviteRepositoryMockFindActiveByCodeParams{ctx, code} + mmFindActiveByCode.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmFindActiveByCode.expectations { + if minimock.Equal(e.params, mmFindActiveByCode.defaultExpectation.params) { + mmFindActiveByCode.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmFindActiveByCode.defaultExpectation.params) + } + } + + return mmFindActiveByCode +} + +// ExpectCtxParam1 sets up expected param ctx for InviteRepository.FindActiveByCode +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) ExpectCtxParam1(ctx context.Context) *mInviteRepositoryMockFindActiveByCode { + if mmFindActiveByCode.mock.funcFindActiveByCode != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by Set") + } + + if mmFindActiveByCode.defaultExpectation == nil { + mmFindActiveByCode.defaultExpectation = &InviteRepositoryMockFindActiveByCodeExpectation{} + } + + if mmFindActiveByCode.defaultExpectation.params != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by Expect") + } + + if mmFindActiveByCode.defaultExpectation.paramPtrs == nil { + mmFindActiveByCode.defaultExpectation.paramPtrs = &InviteRepositoryMockFindActiveByCodeParamPtrs{} + } + mmFindActiveByCode.defaultExpectation.paramPtrs.ctx = &ctx + mmFindActiveByCode.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmFindActiveByCode +} + +// ExpectCodeParam2 sets up expected param code for InviteRepository.FindActiveByCode +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) ExpectCodeParam2(code int64) *mInviteRepositoryMockFindActiveByCode { + if mmFindActiveByCode.mock.funcFindActiveByCode != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by Set") + } + + if mmFindActiveByCode.defaultExpectation == nil { + mmFindActiveByCode.defaultExpectation = &InviteRepositoryMockFindActiveByCodeExpectation{} + } + + if mmFindActiveByCode.defaultExpectation.params != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by Expect") + } + + if mmFindActiveByCode.defaultExpectation.paramPtrs == nil { + mmFindActiveByCode.defaultExpectation.paramPtrs = &InviteRepositoryMockFindActiveByCodeParamPtrs{} + } + mmFindActiveByCode.defaultExpectation.paramPtrs.code = &code + mmFindActiveByCode.defaultExpectation.expectationOrigins.originCode = minimock.CallerInfo(1) + + return mmFindActiveByCode +} + +// Inspect accepts an inspector function that has same arguments as the InviteRepository.FindActiveByCode +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) Inspect(f func(ctx context.Context, code int64)) *mInviteRepositoryMockFindActiveByCode { + if mmFindActiveByCode.mock.inspectFuncFindActiveByCode != nil { + mmFindActiveByCode.mock.t.Fatalf("Inspect function is already set for InviteRepositoryMock.FindActiveByCode") + } + + mmFindActiveByCode.mock.inspectFuncFindActiveByCode = f + + return mmFindActiveByCode +} + +// Return sets up results that will be returned by InviteRepository.FindActiveByCode +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) Return(ip1 *model.InviteCode, err error) *InviteRepositoryMock { + if mmFindActiveByCode.mock.funcFindActiveByCode != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by Set") + } + + if mmFindActiveByCode.defaultExpectation == nil { + mmFindActiveByCode.defaultExpectation = &InviteRepositoryMockFindActiveByCodeExpectation{mock: mmFindActiveByCode.mock} + } + mmFindActiveByCode.defaultExpectation.results = &InviteRepositoryMockFindActiveByCodeResults{ip1, err} + mmFindActiveByCode.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmFindActiveByCode.mock +} + +// Set uses given function f to mock the InviteRepository.FindActiveByCode method +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) Set(f func(ctx context.Context, code int64) (ip1 *model.InviteCode, err error)) *InviteRepositoryMock { + if mmFindActiveByCode.defaultExpectation != nil { + mmFindActiveByCode.mock.t.Fatalf("Default expectation is already set for the InviteRepository.FindActiveByCode method") + } + + if len(mmFindActiveByCode.expectations) > 0 { + mmFindActiveByCode.mock.t.Fatalf("Some expectations are already set for the InviteRepository.FindActiveByCode method") + } + + mmFindActiveByCode.mock.funcFindActiveByCode = f + mmFindActiveByCode.mock.funcFindActiveByCodeOrigin = minimock.CallerInfo(1) + return mmFindActiveByCode.mock +} + +// When sets expectation for the InviteRepository.FindActiveByCode which will trigger the result defined by the following +// Then helper +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) When(ctx context.Context, code int64) *InviteRepositoryMockFindActiveByCodeExpectation { + if mmFindActiveByCode.mock.funcFindActiveByCode != nil { + mmFindActiveByCode.mock.t.Fatalf("InviteRepositoryMock.FindActiveByCode mock is already set by Set") + } + + expectation := &InviteRepositoryMockFindActiveByCodeExpectation{ + mock: mmFindActiveByCode.mock, + params: &InviteRepositoryMockFindActiveByCodeParams{ctx, code}, + expectationOrigins: InviteRepositoryMockFindActiveByCodeExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmFindActiveByCode.expectations = append(mmFindActiveByCode.expectations, expectation) + return expectation +} + +// Then sets up InviteRepository.FindActiveByCode return parameters for the expectation previously defined by the When method +func (e *InviteRepositoryMockFindActiveByCodeExpectation) Then(ip1 *model.InviteCode, err error) *InviteRepositoryMock { + e.results = &InviteRepositoryMockFindActiveByCodeResults{ip1, err} + return e.mock +} + +// Times sets number of times InviteRepository.FindActiveByCode should be invoked +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) Times(n uint64) *mInviteRepositoryMockFindActiveByCode { + if n == 0 { + mmFindActiveByCode.mock.t.Fatalf("Times of InviteRepositoryMock.FindActiveByCode mock can not be zero") + } + mm_atomic.StoreUint64(&mmFindActiveByCode.expectedInvocations, n) + mmFindActiveByCode.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmFindActiveByCode +} + +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) invocationsDone() bool { + if len(mmFindActiveByCode.expectations) == 0 && mmFindActiveByCode.defaultExpectation == nil && mmFindActiveByCode.mock.funcFindActiveByCode == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmFindActiveByCode.mock.afterFindActiveByCodeCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmFindActiveByCode.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// FindActiveByCode implements mm_repository.InviteRepository +func (mmFindActiveByCode *InviteRepositoryMock) FindActiveByCode(ctx context.Context, code int64) (ip1 *model.InviteCode, err error) { + mm_atomic.AddUint64(&mmFindActiveByCode.beforeFindActiveByCodeCounter, 1) + defer mm_atomic.AddUint64(&mmFindActiveByCode.afterFindActiveByCodeCounter, 1) + + mmFindActiveByCode.t.Helper() + + if mmFindActiveByCode.inspectFuncFindActiveByCode != nil { + mmFindActiveByCode.inspectFuncFindActiveByCode(ctx, code) + } + + mm_params := InviteRepositoryMockFindActiveByCodeParams{ctx, code} + + // Record call args + mmFindActiveByCode.FindActiveByCodeMock.mutex.Lock() + mmFindActiveByCode.FindActiveByCodeMock.callArgs = append(mmFindActiveByCode.FindActiveByCodeMock.callArgs, &mm_params) + mmFindActiveByCode.FindActiveByCodeMock.mutex.Unlock() + + for _, e := range mmFindActiveByCode.FindActiveByCodeMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.ip1, e.results.err + } + } + + if mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation.Counter, 1) + mm_want := mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation.params + mm_want_ptrs := mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation.paramPtrs + + mm_got := InviteRepositoryMockFindActiveByCodeParams{ctx, code} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmFindActiveByCode.t.Errorf("InviteRepositoryMock.FindActiveByCode got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.code != nil && !minimock.Equal(*mm_want_ptrs.code, mm_got.code) { + mmFindActiveByCode.t.Errorf("InviteRepositoryMock.FindActiveByCode got unexpected parameter code, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation.expectationOrigins.originCode, *mm_want_ptrs.code, mm_got.code, minimock.Diff(*mm_want_ptrs.code, mm_got.code)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmFindActiveByCode.t.Errorf("InviteRepositoryMock.FindActiveByCode got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmFindActiveByCode.FindActiveByCodeMock.defaultExpectation.results + if mm_results == nil { + mmFindActiveByCode.t.Fatal("No results are set for the InviteRepositoryMock.FindActiveByCode") + } + return (*mm_results).ip1, (*mm_results).err + } + if mmFindActiveByCode.funcFindActiveByCode != nil { + return mmFindActiveByCode.funcFindActiveByCode(ctx, code) + } + mmFindActiveByCode.t.Fatalf("Unexpected call to InviteRepositoryMock.FindActiveByCode. %v %v", ctx, code) + return +} + +// FindActiveByCodeAfterCounter returns a count of finished InviteRepositoryMock.FindActiveByCode invocations +func (mmFindActiveByCode *InviteRepositoryMock) FindActiveByCodeAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmFindActiveByCode.afterFindActiveByCodeCounter) +} + +// FindActiveByCodeBeforeCounter returns a count of InviteRepositoryMock.FindActiveByCode invocations +func (mmFindActiveByCode *InviteRepositoryMock) FindActiveByCodeBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmFindActiveByCode.beforeFindActiveByCodeCounter) +} + +// Calls returns a list of arguments used in each call to InviteRepositoryMock.FindActiveByCode. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmFindActiveByCode *mInviteRepositoryMockFindActiveByCode) Calls() []*InviteRepositoryMockFindActiveByCodeParams { + mmFindActiveByCode.mutex.RLock() + + argCopy := make([]*InviteRepositoryMockFindActiveByCodeParams, len(mmFindActiveByCode.callArgs)) + copy(argCopy, mmFindActiveByCode.callArgs) + + mmFindActiveByCode.mutex.RUnlock() + + return argCopy +} + +// MinimockFindActiveByCodeDone returns true if the count of the FindActiveByCode invocations corresponds +// the number of defined expectations +func (m *InviteRepositoryMock) MinimockFindActiveByCodeDone() bool { + if m.FindActiveByCodeMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.FindActiveByCodeMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.FindActiveByCodeMock.invocationsDone() +} + +// MinimockFindActiveByCodeInspect logs each unmet expectation +func (m *InviteRepositoryMock) MinimockFindActiveByCodeInspect() { + for _, e := range m.FindActiveByCodeMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByCode at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterFindActiveByCodeCounter := mm_atomic.LoadUint64(&m.afterFindActiveByCodeCounter) + // if default expectation was set then invocations count should be greater than zero + if m.FindActiveByCodeMock.defaultExpectation != nil && afterFindActiveByCodeCounter < 1 { + if m.FindActiveByCodeMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByCode at\n%s", m.FindActiveByCodeMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByCode at\n%s with params: %#v", m.FindActiveByCodeMock.defaultExpectation.expectationOrigins.origin, *m.FindActiveByCodeMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcFindActiveByCode != nil && afterFindActiveByCodeCounter < 1 { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByCode at\n%s", m.funcFindActiveByCodeOrigin) + } + + if !m.FindActiveByCodeMock.invocationsDone() && afterFindActiveByCodeCounter > 0 { + m.t.Errorf("Expected %d calls to InviteRepositoryMock.FindActiveByCode at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.FindActiveByCodeMock.expectedInvocations), m.FindActiveByCodeMock.expectedInvocationsOrigin, afterFindActiveByCodeCounter) + } +} + type mInviteRepositoryMockFindByCode struct { optional bool mock *InviteRepositoryMock @@ -2159,6 +2895,10 @@ func (m *InviteRepositoryMock) MinimockFinish() { m.MinimockDeactivateExpiredInspect() + m.MinimockDecrementCanBeUsedCountTxInspect() + + m.MinimockFindActiveByCodeInspect() + m.MinimockFindByCodeInspect() m.MinimockGetUserInvitesInspect() @@ -2190,6 +2930,8 @@ func (m *InviteRepositoryMock) minimockDone() bool { m.MinimockCreateDone() && m.MinimockCreateTxDone() && m.MinimockDeactivateExpiredDone() && + m.MinimockDecrementCanBeUsedCountTxDone() && + m.MinimockFindActiveByCodeDone() && m.MinimockFindByCodeDone() && m.MinimockGetUserInvitesDone() && m.MinimockIncrementUsedCountDone() diff --git a/internal/mocks/user_repository_mock.go b/internal/mocks/user_repository_mock.go index 0fa4f8f..21d77f5 100644 --- a/internal/mocks/user_repository_mock.go +++ b/internal/mocks/user_repository_mock.go @@ -41,6 +41,13 @@ type UserRepositoryMock struct { beforeCreateCounter uint64 CreateMock mUserRepositoryMockCreate + funcCreateTx func(ctx context.Context, tx pgx.Tx, user *model.User) (err error) + funcCreateTxOrigin string + inspectFuncCreateTx func(ctx context.Context, tx pgx.Tx, user *model.User) + afterCreateTxCounter uint64 + beforeCreateTxCounter uint64 + CreateTxMock mUserRepositoryMockCreateTx + funcFindByEmailHash func(ctx context.Context, emailHash string) (up1 *model.User, err error) funcFindByEmailHashOrigin string inspectFuncFindByEmailHash func(ctx context.Context, emailHash string) @@ -108,6 +115,9 @@ func NewUserRepositoryMock(t minimock.Tester) *UserRepositoryMock { m.CreateMock = mUserRepositoryMockCreate{mock: m} m.CreateMock.callArgs = []*UserRepositoryMockCreateParams{} + m.CreateTxMock = mUserRepositoryMockCreateTx{mock: m} + m.CreateTxMock.callArgs = []*UserRepositoryMockCreateTxParams{} + m.FindByEmailHashMock = mUserRepositoryMockFindByEmailHash{mock: m} m.FindByEmailHashMock.callArgs = []*UserRepositoryMockFindByEmailHashParams{} @@ -1193,6 +1203,379 @@ func (m *UserRepositoryMock) MinimockCreateInspect() { } } +type mUserRepositoryMockCreateTx struct { + optional bool + mock *UserRepositoryMock + defaultExpectation *UserRepositoryMockCreateTxExpectation + expectations []*UserRepositoryMockCreateTxExpectation + + callArgs []*UserRepositoryMockCreateTxParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// UserRepositoryMockCreateTxExpectation specifies expectation struct of the UserRepository.CreateTx +type UserRepositoryMockCreateTxExpectation struct { + mock *UserRepositoryMock + params *UserRepositoryMockCreateTxParams + paramPtrs *UserRepositoryMockCreateTxParamPtrs + expectationOrigins UserRepositoryMockCreateTxExpectationOrigins + results *UserRepositoryMockCreateTxResults + returnOrigin string + Counter uint64 +} + +// UserRepositoryMockCreateTxParams contains parameters of the UserRepository.CreateTx +type UserRepositoryMockCreateTxParams struct { + ctx context.Context + tx pgx.Tx + user *model.User +} + +// UserRepositoryMockCreateTxParamPtrs contains pointers to parameters of the UserRepository.CreateTx +type UserRepositoryMockCreateTxParamPtrs struct { + ctx *context.Context + tx *pgx.Tx + user **model.User +} + +// UserRepositoryMockCreateTxResults contains results of the UserRepository.CreateTx +type UserRepositoryMockCreateTxResults struct { + err error +} + +// UserRepositoryMockCreateTxOrigins contains origins of expectations of the UserRepository.CreateTx +type UserRepositoryMockCreateTxExpectationOrigins struct { + origin string + originCtx string + originTx string + originUser string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmCreateTx *mUserRepositoryMockCreateTx) Optional() *mUserRepositoryMockCreateTx { + mmCreateTx.optional = true + return mmCreateTx +} + +// Expect sets up expected params for UserRepository.CreateTx +func (mmCreateTx *mUserRepositoryMockCreateTx) Expect(ctx context.Context, tx pgx.Tx, user *model.User) *mUserRepositoryMockCreateTx { + if mmCreateTx.mock.funcCreateTx != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Set") + } + + if mmCreateTx.defaultExpectation == nil { + mmCreateTx.defaultExpectation = &UserRepositoryMockCreateTxExpectation{} + } + + if mmCreateTx.defaultExpectation.paramPtrs != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by ExpectParams functions") + } + + mmCreateTx.defaultExpectation.params = &UserRepositoryMockCreateTxParams{ctx, tx, user} + mmCreateTx.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmCreateTx.expectations { + if minimock.Equal(e.params, mmCreateTx.defaultExpectation.params) { + mmCreateTx.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmCreateTx.defaultExpectation.params) + } + } + + return mmCreateTx +} + +// ExpectCtxParam1 sets up expected param ctx for UserRepository.CreateTx +func (mmCreateTx *mUserRepositoryMockCreateTx) ExpectCtxParam1(ctx context.Context) *mUserRepositoryMockCreateTx { + if mmCreateTx.mock.funcCreateTx != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Set") + } + + if mmCreateTx.defaultExpectation == nil { + mmCreateTx.defaultExpectation = &UserRepositoryMockCreateTxExpectation{} + } + + if mmCreateTx.defaultExpectation.params != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Expect") + } + + if mmCreateTx.defaultExpectation.paramPtrs == nil { + mmCreateTx.defaultExpectation.paramPtrs = &UserRepositoryMockCreateTxParamPtrs{} + } + mmCreateTx.defaultExpectation.paramPtrs.ctx = &ctx + mmCreateTx.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmCreateTx +} + +// ExpectTxParam2 sets up expected param tx for UserRepository.CreateTx +func (mmCreateTx *mUserRepositoryMockCreateTx) ExpectTxParam2(tx pgx.Tx) *mUserRepositoryMockCreateTx { + if mmCreateTx.mock.funcCreateTx != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Set") + } + + if mmCreateTx.defaultExpectation == nil { + mmCreateTx.defaultExpectation = &UserRepositoryMockCreateTxExpectation{} + } + + if mmCreateTx.defaultExpectation.params != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Expect") + } + + if mmCreateTx.defaultExpectation.paramPtrs == nil { + mmCreateTx.defaultExpectation.paramPtrs = &UserRepositoryMockCreateTxParamPtrs{} + } + mmCreateTx.defaultExpectation.paramPtrs.tx = &tx + mmCreateTx.defaultExpectation.expectationOrigins.originTx = minimock.CallerInfo(1) + + return mmCreateTx +} + +// ExpectUserParam3 sets up expected param user for UserRepository.CreateTx +func (mmCreateTx *mUserRepositoryMockCreateTx) ExpectUserParam3(user *model.User) *mUserRepositoryMockCreateTx { + if mmCreateTx.mock.funcCreateTx != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Set") + } + + if mmCreateTx.defaultExpectation == nil { + mmCreateTx.defaultExpectation = &UserRepositoryMockCreateTxExpectation{} + } + + if mmCreateTx.defaultExpectation.params != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Expect") + } + + if mmCreateTx.defaultExpectation.paramPtrs == nil { + mmCreateTx.defaultExpectation.paramPtrs = &UserRepositoryMockCreateTxParamPtrs{} + } + mmCreateTx.defaultExpectation.paramPtrs.user = &user + mmCreateTx.defaultExpectation.expectationOrigins.originUser = minimock.CallerInfo(1) + + return mmCreateTx +} + +// Inspect accepts an inspector function that has same arguments as the UserRepository.CreateTx +func (mmCreateTx *mUserRepositoryMockCreateTx) Inspect(f func(ctx context.Context, tx pgx.Tx, user *model.User)) *mUserRepositoryMockCreateTx { + if mmCreateTx.mock.inspectFuncCreateTx != nil { + mmCreateTx.mock.t.Fatalf("Inspect function is already set for UserRepositoryMock.CreateTx") + } + + mmCreateTx.mock.inspectFuncCreateTx = f + + return mmCreateTx +} + +// Return sets up results that will be returned by UserRepository.CreateTx +func (mmCreateTx *mUserRepositoryMockCreateTx) Return(err error) *UserRepositoryMock { + if mmCreateTx.mock.funcCreateTx != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Set") + } + + if mmCreateTx.defaultExpectation == nil { + mmCreateTx.defaultExpectation = &UserRepositoryMockCreateTxExpectation{mock: mmCreateTx.mock} + } + mmCreateTx.defaultExpectation.results = &UserRepositoryMockCreateTxResults{err} + mmCreateTx.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmCreateTx.mock +} + +// Set uses given function f to mock the UserRepository.CreateTx method +func (mmCreateTx *mUserRepositoryMockCreateTx) Set(f func(ctx context.Context, tx pgx.Tx, user *model.User) (err error)) *UserRepositoryMock { + if mmCreateTx.defaultExpectation != nil { + mmCreateTx.mock.t.Fatalf("Default expectation is already set for the UserRepository.CreateTx method") + } + + if len(mmCreateTx.expectations) > 0 { + mmCreateTx.mock.t.Fatalf("Some expectations are already set for the UserRepository.CreateTx method") + } + + mmCreateTx.mock.funcCreateTx = f + mmCreateTx.mock.funcCreateTxOrigin = minimock.CallerInfo(1) + return mmCreateTx.mock +} + +// When sets expectation for the UserRepository.CreateTx which will trigger the result defined by the following +// Then helper +func (mmCreateTx *mUserRepositoryMockCreateTx) When(ctx context.Context, tx pgx.Tx, user *model.User) *UserRepositoryMockCreateTxExpectation { + if mmCreateTx.mock.funcCreateTx != nil { + mmCreateTx.mock.t.Fatalf("UserRepositoryMock.CreateTx mock is already set by Set") + } + + expectation := &UserRepositoryMockCreateTxExpectation{ + mock: mmCreateTx.mock, + params: &UserRepositoryMockCreateTxParams{ctx, tx, user}, + expectationOrigins: UserRepositoryMockCreateTxExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmCreateTx.expectations = append(mmCreateTx.expectations, expectation) + return expectation +} + +// Then sets up UserRepository.CreateTx return parameters for the expectation previously defined by the When method +func (e *UserRepositoryMockCreateTxExpectation) Then(err error) *UserRepositoryMock { + e.results = &UserRepositoryMockCreateTxResults{err} + return e.mock +} + +// Times sets number of times UserRepository.CreateTx should be invoked +func (mmCreateTx *mUserRepositoryMockCreateTx) Times(n uint64) *mUserRepositoryMockCreateTx { + if n == 0 { + mmCreateTx.mock.t.Fatalf("Times of UserRepositoryMock.CreateTx mock can not be zero") + } + mm_atomic.StoreUint64(&mmCreateTx.expectedInvocations, n) + mmCreateTx.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmCreateTx +} + +func (mmCreateTx *mUserRepositoryMockCreateTx) invocationsDone() bool { + if len(mmCreateTx.expectations) == 0 && mmCreateTx.defaultExpectation == nil && mmCreateTx.mock.funcCreateTx == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmCreateTx.mock.afterCreateTxCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmCreateTx.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// CreateTx implements mm_repository.UserRepository +func (mmCreateTx *UserRepositoryMock) CreateTx(ctx context.Context, tx pgx.Tx, user *model.User) (err error) { + mm_atomic.AddUint64(&mmCreateTx.beforeCreateTxCounter, 1) + defer mm_atomic.AddUint64(&mmCreateTx.afterCreateTxCounter, 1) + + mmCreateTx.t.Helper() + + if mmCreateTx.inspectFuncCreateTx != nil { + mmCreateTx.inspectFuncCreateTx(ctx, tx, user) + } + + mm_params := UserRepositoryMockCreateTxParams{ctx, tx, user} + + // Record call args + mmCreateTx.CreateTxMock.mutex.Lock() + mmCreateTx.CreateTxMock.callArgs = append(mmCreateTx.CreateTxMock.callArgs, &mm_params) + mmCreateTx.CreateTxMock.mutex.Unlock() + + for _, e := range mmCreateTx.CreateTxMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.err + } + } + + if mmCreateTx.CreateTxMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmCreateTx.CreateTxMock.defaultExpectation.Counter, 1) + mm_want := mmCreateTx.CreateTxMock.defaultExpectation.params + mm_want_ptrs := mmCreateTx.CreateTxMock.defaultExpectation.paramPtrs + + mm_got := UserRepositoryMockCreateTxParams{ctx, tx, user} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmCreateTx.t.Errorf("UserRepositoryMock.CreateTx got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCreateTx.CreateTxMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.tx != nil && !minimock.Equal(*mm_want_ptrs.tx, mm_got.tx) { + mmCreateTx.t.Errorf("UserRepositoryMock.CreateTx got unexpected parameter tx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCreateTx.CreateTxMock.defaultExpectation.expectationOrigins.originTx, *mm_want_ptrs.tx, mm_got.tx, minimock.Diff(*mm_want_ptrs.tx, mm_got.tx)) + } + + if mm_want_ptrs.user != nil && !minimock.Equal(*mm_want_ptrs.user, mm_got.user) { + mmCreateTx.t.Errorf("UserRepositoryMock.CreateTx got unexpected parameter user, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCreateTx.CreateTxMock.defaultExpectation.expectationOrigins.originUser, *mm_want_ptrs.user, mm_got.user, minimock.Diff(*mm_want_ptrs.user, mm_got.user)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmCreateTx.t.Errorf("UserRepositoryMock.CreateTx got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCreateTx.CreateTxMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmCreateTx.CreateTxMock.defaultExpectation.results + if mm_results == nil { + mmCreateTx.t.Fatal("No results are set for the UserRepositoryMock.CreateTx") + } + return (*mm_results).err + } + if mmCreateTx.funcCreateTx != nil { + return mmCreateTx.funcCreateTx(ctx, tx, user) + } + mmCreateTx.t.Fatalf("Unexpected call to UserRepositoryMock.CreateTx. %v %v %v", ctx, tx, user) + return +} + +// CreateTxAfterCounter returns a count of finished UserRepositoryMock.CreateTx invocations +func (mmCreateTx *UserRepositoryMock) CreateTxAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmCreateTx.afterCreateTxCounter) +} + +// CreateTxBeforeCounter returns a count of UserRepositoryMock.CreateTx invocations +func (mmCreateTx *UserRepositoryMock) CreateTxBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmCreateTx.beforeCreateTxCounter) +} + +// Calls returns a list of arguments used in each call to UserRepositoryMock.CreateTx. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmCreateTx *mUserRepositoryMockCreateTx) Calls() []*UserRepositoryMockCreateTxParams { + mmCreateTx.mutex.RLock() + + argCopy := make([]*UserRepositoryMockCreateTxParams, len(mmCreateTx.callArgs)) + copy(argCopy, mmCreateTx.callArgs) + + mmCreateTx.mutex.RUnlock() + + return argCopy +} + +// MinimockCreateTxDone returns true if the count of the CreateTx invocations corresponds +// the number of defined expectations +func (m *UserRepositoryMock) MinimockCreateTxDone() bool { + if m.CreateTxMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.CreateTxMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.CreateTxMock.invocationsDone() +} + +// MinimockCreateTxInspect logs each unmet expectation +func (m *UserRepositoryMock) MinimockCreateTxInspect() { + for _, e := range m.CreateTxMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to UserRepositoryMock.CreateTx at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterCreateTxCounter := mm_atomic.LoadUint64(&m.afterCreateTxCounter) + // if default expectation was set then invocations count should be greater than zero + if m.CreateTxMock.defaultExpectation != nil && afterCreateTxCounter < 1 { + if m.CreateTxMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to UserRepositoryMock.CreateTx at\n%s", m.CreateTxMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to UserRepositoryMock.CreateTx at\n%s with params: %#v", m.CreateTxMock.defaultExpectation.expectationOrigins.origin, *m.CreateTxMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcCreateTx != nil && afterCreateTxCounter < 1 { + m.t.Errorf("Expected call to UserRepositoryMock.CreateTx at\n%s", m.funcCreateTxOrigin) + } + + if !m.CreateTxMock.invocationsDone() && afterCreateTxCounter > 0 { + m.t.Errorf("Expected %d calls to UserRepositoryMock.CreateTx at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.CreateTxMock.expectedInvocations), m.CreateTxMock.expectedInvocationsOrigin, afterCreateTxCounter) + } +} + type mUserRepositoryMockFindByEmailHash struct { optional bool mock *UserRepositoryMock @@ -3724,6 +4107,8 @@ func (m *UserRepositoryMock) MinimockFinish() { m.MinimockCreateInspect() + m.MinimockCreateTxInspect() + m.MinimockFindByEmailHashInspect() m.MinimockFindByIDInspect() @@ -3763,6 +4148,7 @@ func (m *UserRepositoryMock) minimockDone() bool { m.MinimockCheckInviteLimitDone() && m.MinimockCheckInviteLimitTxDone() && m.MinimockCreateDone() && + m.MinimockCreateTxDone() && m.MinimockFindByEmailHashDone() && m.MinimockFindByIDDone() && m.MinimockGetBalanceDone() && diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index a9e19fc..c4a98ad 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -12,6 +12,7 @@ type UserRepository interface { FindByEmailHash(ctx context.Context, emailHash string) (*model.User, error) FindByID(ctx context.Context, userID int) (*model.User, error) Create(ctx context.Context, user *model.User) error + CreateTx(ctx context.Context, tx pgx.Tx, user *model.User) error UpdateBalance(ctx context.Context, userID int, delta float64) error UpdateBalanceTx(ctx context.Context, tx pgx.Tx, userID int, delta float64) error GetBalance(ctx context.Context, userID int) (float64, error) @@ -35,7 +36,9 @@ type InviteRepository interface { Create(ctx context.Context, invite *model.InviteCode) error CreateTx(ctx context.Context, tx pgx.Tx, invite *model.InviteCode) error FindByCode(ctx context.Context, code int64) (*model.InviteCode, error) + FindActiveByCode(ctx context.Context, code int64) (*model.InviteCode, error) IncrementUsedCount(ctx context.Context, code int64) error + DecrementCanBeUsedCountTx(ctx context.Context, tx pgx.Tx, code int64) error DeactivateExpired(ctx context.Context) (int, error) GetUserInvites(ctx context.Context, userID int) ([]*model.InviteCode, error) } diff --git a/internal/repository/invite.go b/internal/repository/invite.go index b796876..f1453ab 100644 --- a/internal/repository/invite.go +++ b/internal/repository/invite.go @@ -78,6 +78,38 @@ func (r *inviteRepository) FindByCode(ctx context.Context, code int64) (*model.I return invite, nil } +func (r *inviteRepository) FindActiveByCode(ctx context.Context, code int64) (*model.InviteCode, error) { + query := r.qb.Select( + "id", "user_id", "code", "can_be_used_count", "used_count", + "is_active", "created_at", "expires_at", + ).From("invite_codes").Where(sq.And{ + sq.Eq{"code": code}, + sq.Eq{"is_active": true}, + sq.Expr("expires_at > now()"), + sq.Expr("can_be_used_count > used_count"), + }) + + sqlQuery, args, err := query.ToSql() + if err != nil { + return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) + } + + invite := &model.InviteCode{} + err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( + &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, + &invite.UsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, errs.NewBusinessError(errs.InviteInvalidOrExpired, "invite code is invalid or expired") + } + if err != nil { + return nil, errs.NewInternalError(errs.DatabaseError, "failed to find active invite code", err) + } + + return invite, nil +} + func (r *inviteRepository) IncrementUsedCount(ctx context.Context, code int64) error { query := r.qb.Update("invite_codes"). Set("used_count", sq.Expr("used_count + 1")). @@ -96,6 +128,25 @@ func (r *inviteRepository) IncrementUsedCount(ctx context.Context, code int64) e return nil } +func (r *inviteRepository) DecrementCanBeUsedCountTx(ctx context.Context, tx pgx.Tx, code int64) error { + query := r.qb.Update("invite_codes"). + Set("used_count", sq.Expr("used_count + 1")). + Set("is_active", sq.Expr("CASE WHEN used_count + 1 >= can_be_used_count THEN false ELSE is_active END")). + Where(sq.Eq{"code": code}) + + sqlQuery, args, err := query.ToSql() + if err != nil { + return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) + } + + _, err = tx.Exec(ctx, sqlQuery, args...) + if err != nil { + return errs.NewInternalError(errs.DatabaseError, "failed to decrement can_be_used_count", err) + } + + return nil +} + func (r *inviteRepository) DeactivateExpired(ctx context.Context) (int, error) { query := r.qb.Update("invite_codes"). Set("is_active", false). diff --git a/internal/repository/user.go b/internal/repository/user.go index 5533417..85ced3a 100644 --- a/internal/repository/user.go +++ b/internal/repository/user.go @@ -86,6 +86,14 @@ func (r *userRepository) FindByID(ctx context.Context, userID int) (*model.User, } func (r *userRepository) Create(ctx context.Context, user *model.User) error { + return r.createWithExecutor(ctx, r.pool, user) +} + +func (r *userRepository) CreateTx(ctx context.Context, tx pgx.Tx, user *model.User) error { + return r.createWithExecutor(ctx, tx, user) +} + +func (r *userRepository) createWithExecutor(ctx context.Context, exec DBTX, user *model.User) error { encryptedEmail, err := r.cryptoHelper.Encrypt(user.Email) if err != nil { return errs.NewInternalError(errs.EncryptionError, "failed to encrypt email", err) @@ -114,7 +122,7 @@ func (r *userRepository) Create(ctx context.Context, user *model.User) error { return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } - err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&user.ID) + err = exec.QueryRow(ctx, sqlQuery, args...).Scan(&user.ID) if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to create user", err) } diff --git a/internal/service/auth.go b/internal/service/auth.go index e5063f3..a14d576 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -9,19 +9,24 @@ import ( "git.techease.ru/Smart-search/smart-search-back/pkg/crypto" "git.techease.ru/Smart-search/smart-search-back/pkg/errors" "git.techease.ru/Smart-search/smart-search-back/pkg/jwt" + "github.com/jackc/pgx/v5" ) type authService struct { userRepo repository.UserRepository sessionRepo repository.SessionRepository + inviteRepo repository.InviteRepository + txManager *repository.TxManager jwtSecret string cryptoHelper *crypto.Crypto } -func NewAuthService(userRepo repository.UserRepository, sessionRepo repository.SessionRepository, jwtSecret, cryptoSecret string) AuthService { +func NewAuthService(userRepo repository.UserRepository, sessionRepo repository.SessionRepository, inviteRepo repository.InviteRepository, txManager *repository.TxManager, jwtSecret, cryptoSecret string) AuthService { return &authService{ userRepo: userRepo, sessionRepo: sessionRepo, + inviteRepo: inviteRepo, + txManager: txManager, jwtSecret: jwtSecret, cryptoHelper: crypto.NewCrypto(cryptoSecret), } @@ -114,3 +119,68 @@ func (s *authService) Validate(ctx context.Context, accessToken string) (int, er func (s *authService) Logout(ctx context.Context, accessToken string) error { return s.sessionRepo.RevokeByAccessToken(ctx, accessToken) } + +func (s *authService) Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error) { + _, err = s.inviteRepo.FindActiveByCode(ctx, inviteCode) + if err != nil { + return "", "", err + } + + emailHash := s.cryptoHelper.EmailHash(email) + existingUser, err := s.userRepo.FindByEmailHash(ctx, emailHash) + if existingUser != nil { + return "", "", errors.NewBusinessError(errors.EmailAlreadyExists, "email already registered") + } + if err != nil && !errors.IsBusinessError(err, errors.UserNotFound) { + return "", "", err + } + + user := &model.User{ + Email: email, + EmailHash: emailHash, + PasswordHash: crypto.PasswordHash(password), + Phone: phone, + UserName: name, + Balance: 0, + } + + err = s.txManager.WithTx(ctx, func(tx pgx.Tx) error { + if err := s.userRepo.CreateTx(ctx, tx, user); err != nil { + return err + } + + if err := s.inviteRepo.DecrementCanBeUsedCountTx(ctx, tx, inviteCode); err != nil { + return err + } + + return nil + }) + if err != nil { + return "", "", err + } + + accessToken, err = jwt.GenerateAccessToken(user.ID, s.jwtSecret) + if err != nil { + return "", "", errors.NewInternalError(errors.InternalError, "failed to generate access token", err) + } + + refreshToken, err = jwt.GenerateRefreshToken(user.ID, s.jwtSecret) + if err != nil { + return "", "", errors.NewInternalError(errors.InternalError, "failed to generate refresh token", err) + } + + session := &model.Session{ + UserID: user.ID, + AccessToken: accessToken, + RefreshToken: refreshToken, + IP: ip, + UserAgent: userAgent, + ExpiresAt: time.Now().Add(30 * 24 * time.Hour), + } + + if err := s.sessionRepo.Create(ctx, session); err != nil { + return "", "", err + } + + return accessToken, refreshToken, nil +} diff --git a/internal/service/interfaces.go b/internal/service/interfaces.go index c8e2ec9..ffa09fa 100644 --- a/internal/service/interfaces.go +++ b/internal/service/interfaces.go @@ -9,6 +9,7 @@ import ( ) type AuthService interface { + Register(ctx context.Context, email, password, name, phone string, inviteCode int64, ip, userAgent string) (accessToken, refreshToken string, err error) Login(ctx context.Context, email, password, ip, userAgent string) (accessToken, refreshToken string, err error) Refresh(ctx context.Context, refreshToken string) (string, error) Validate(ctx context.Context, accessToken string) (int, error) diff --git a/internal/service/tests/auth_suite_test.go b/internal/service/tests/auth_suite_test.go index 432b4fb..4619e46 100644 --- a/internal/service/tests/auth_suite_test.go +++ b/internal/service/tests/auth_suite_test.go @@ -29,6 +29,7 @@ type Suite struct { authService service.AuthService userRepo *mocks.UserRepositoryMock sessionRepo *mocks.SessionRepositoryMock + inviteRepo *mocks.InviteRepositoryMock crypto *crypto.Crypto } @@ -52,9 +53,10 @@ func (s *Suite) SetupTest() { s.userRepo = mocks.NewUserRepositoryMock(ctrl) s.sessionRepo = mocks.NewSessionRepositoryMock(ctrl) + s.inviteRepo = mocks.NewInviteRepositoryMock(ctrl) s.crypto = crypto.NewCrypto(testCryptoSecret) - s.authService = service.NewAuthService(s.userRepo, s.sessionRepo, testJWTSecret, testCryptoSecret) + s.authService = service.NewAuthService(s.userRepo, s.sessionRepo, s.inviteRepo, nil, testJWTSecret, testCryptoSecret) } func createTestUser(password string) *model.User { diff --git a/pkg/errors/codes.go b/pkg/errors/codes.go index aeab8a1..916bdbd 100644 --- a/pkg/errors/codes.go +++ b/pkg/errors/codes.go @@ -6,6 +6,8 @@ const ( AuthInvalidToken = "AUTH_INVALID_TOKEN" RefreshInvalid = "REFRESH_INVALID" InviteLimitReached = "INVITE_LIMIT_REACHED" + InviteInvalidOrExpired = "INVITE_INVALID_OR_EXPIRED" + EmailAlreadyExists = "EMAIL_ALREADY_EXISTS" InsufficientBalance = "INSUFFICIENT_BALANCE" UserNotFound = "USER_NOT_FOUND" RequestNotFound = "REQUEST_NOT_FOUND" diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index 06896c1..b4f6ff3 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -51,6 +51,14 @@ func NewInternalError(code, message string, err error) *AppError { } } +func IsBusinessError(err error, code string) bool { + var appErr *AppError + if !errors.As(err, &appErr) { + return false + } + return appErr.Type == BusinessError && appErr.Code == code +} + func ToGRPCError(err error, zapLogger *zap.Logger, method string) error { if err == nil { return nil @@ -85,8 +93,10 @@ func ToGRPCError(err error, zapLogger *zap.Logger, method string) error { return status.Error(codes.Unauthenticated, appErr.Message) case InviteLimitReached: return status.Error(codes.ResourceExhausted, appErr.Message) - case InsufficientBalance: + case InsufficientBalance, InviteInvalidOrExpired: return status.Error(codes.FailedPrecondition, appErr.Message) + case EmailAlreadyExists: + return status.Error(codes.AlreadyExists, appErr.Message) case UserNotFound, RequestNotFound: return status.Error(codes.NotFound, appErr.Message) default: diff --git a/pkg/pb/auth/auth.pb.go b/pkg/pb/auth/auth.pb.go index 5d83932..438b1c0 100644 --- a/pkg/pb/auth/auth.pb.go +++ b/pkg/pb/auth/auth.pb.go @@ -437,6 +437,150 @@ func (x *LogoutResponse) GetSuccess() bool { return false } +type RegisterRequest struct { + state protoimpl.MessageState `protogen:"open.v1"` + Email string `protobuf:"bytes,1,opt,name=email,proto3" json:"email,omitempty"` + Password string `protobuf:"bytes,2,opt,name=password,proto3" json:"password,omitempty"` + Name string `protobuf:"bytes,3,opt,name=name,proto3" json:"name,omitempty"` + Phone string `protobuf:"bytes,4,opt,name=phone,proto3" json:"phone,omitempty"` + InviteCode int64 `protobuf:"varint,5,opt,name=invite_code,json=inviteCode,proto3" json:"invite_code,omitempty"` + Ip string `protobuf:"bytes,6,opt,name=ip,proto3" json:"ip,omitempty"` + UserAgent string `protobuf:"bytes,7,opt,name=user_agent,json=userAgent,proto3" json:"user_agent,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RegisterRequest) Reset() { + *x = RegisterRequest{} + mi := &file_auth_auth_proto_msgTypes[8] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RegisterRequest) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RegisterRequest) ProtoMessage() {} + +func (x *RegisterRequest) ProtoReflect() protoreflect.Message { + mi := &file_auth_auth_proto_msgTypes[8] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RegisterRequest.ProtoReflect.Descriptor instead. +func (*RegisterRequest) Descriptor() ([]byte, []int) { + return file_auth_auth_proto_rawDescGZIP(), []int{8} +} + +func (x *RegisterRequest) GetEmail() string { + if x != nil { + return x.Email + } + return "" +} + +func (x *RegisterRequest) GetPassword() string { + if x != nil { + return x.Password + } + return "" +} + +func (x *RegisterRequest) GetName() string { + if x != nil { + return x.Name + } + return "" +} + +func (x *RegisterRequest) GetPhone() string { + if x != nil { + return x.Phone + } + return "" +} + +func (x *RegisterRequest) GetInviteCode() int64 { + if x != nil { + return x.InviteCode + } + return 0 +} + +func (x *RegisterRequest) GetIp() string { + if x != nil { + return x.Ip + } + return "" +} + +func (x *RegisterRequest) GetUserAgent() string { + if x != nil { + return x.UserAgent + } + return "" +} + +type RegisterResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + AccessToken string `protobuf:"bytes,1,opt,name=access_token,json=accessToken,proto3" json:"access_token,omitempty"` + RefreshToken string `protobuf:"bytes,2,opt,name=refresh_token,json=refreshToken,proto3" json:"refresh_token,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *RegisterResponse) Reset() { + *x = RegisterResponse{} + mi := &file_auth_auth_proto_msgTypes[9] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *RegisterResponse) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*RegisterResponse) ProtoMessage() {} + +func (x *RegisterResponse) ProtoReflect() protoreflect.Message { + mi := &file_auth_auth_proto_msgTypes[9] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use RegisterResponse.ProtoReflect.Descriptor instead. +func (*RegisterResponse) Descriptor() ([]byte, []int) { + return file_auth_auth_proto_rawDescGZIP(), []int{9} +} + +func (x *RegisterResponse) GetAccessToken() string { + if x != nil { + return x.AccessToken + } + return "" +} + +func (x *RegisterResponse) GetRefreshToken() string { + if x != nil { + return x.RefreshToken + } + return "" +} + var File_auth_auth_proto protoreflect.FileDescriptor const file_auth_auth_proto_rawDesc = "" + @@ -467,8 +611,22 @@ const file_auth_auth_proto_rawDesc = "" + "\rLogoutRequest\x12!\n" + "\faccess_token\x18\x01 \x01(\tR\vaccessToken\"*\n" + "\x0eLogoutResponse\x12\x18\n" + - "\asuccess\x18\x01 \x01(\bR\asuccess2\xe7\x01\n" + - "\vAuthService\x120\n" + + "\asuccess\x18\x01 \x01(\bR\asuccess\"\xbd\x01\n" + + "\x0fRegisterRequest\x12\x14\n" + + "\x05email\x18\x01 \x01(\tR\x05email\x12\x1a\n" + + "\bpassword\x18\x02 \x01(\tR\bpassword\x12\x12\n" + + "\x04name\x18\x03 \x01(\tR\x04name\x12\x14\n" + + "\x05phone\x18\x04 \x01(\tR\x05phone\x12\x1f\n" + + "\vinvite_code\x18\x05 \x01(\x03R\n" + + "inviteCode\x12\x0e\n" + + "\x02ip\x18\x06 \x01(\tR\x02ip\x12\x1d\n" + + "\n" + + "user_agent\x18\a \x01(\tR\tuserAgent\"Z\n" + + "\x10RegisterResponse\x12!\n" + + "\faccess_token\x18\x01 \x01(\tR\vaccessToken\x12#\n" + + "\rrefresh_token\x18\x02 \x01(\tR\frefreshToken2\xa2\x02\n" + + "\vAuthService\x129\n" + + "\bRegister\x12\x15.auth.RegisterRequest\x1a\x16.auth.RegisterResponse\x120\n" + "\x05Login\x12\x12.auth.LoginRequest\x1a\x13.auth.LoginResponse\x126\n" + "\aRefresh\x12\x14.auth.RefreshRequest\x1a\x15.auth.RefreshResponse\x129\n" + "\bValidate\x12\x15.auth.ValidateRequest\x1a\x16.auth.ValidateResponse\x123\n" + @@ -486,7 +644,7 @@ func file_auth_auth_proto_rawDescGZIP() []byte { return file_auth_auth_proto_rawDescData } -var file_auth_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_auth_auth_proto_msgTypes = make([]protoimpl.MessageInfo, 10) var file_auth_auth_proto_goTypes = []any{ (*LoginRequest)(nil), // 0: auth.LoginRequest (*LoginResponse)(nil), // 1: auth.LoginResponse @@ -496,18 +654,22 @@ var file_auth_auth_proto_goTypes = []any{ (*ValidateResponse)(nil), // 5: auth.ValidateResponse (*LogoutRequest)(nil), // 6: auth.LogoutRequest (*LogoutResponse)(nil), // 7: auth.LogoutResponse + (*RegisterRequest)(nil), // 8: auth.RegisterRequest + (*RegisterResponse)(nil), // 9: auth.RegisterResponse } var file_auth_auth_proto_depIdxs = []int32{ - 0, // 0: auth.AuthService.Login:input_type -> auth.LoginRequest - 2, // 1: auth.AuthService.Refresh:input_type -> auth.RefreshRequest - 4, // 2: auth.AuthService.Validate:input_type -> auth.ValidateRequest - 6, // 3: auth.AuthService.Logout:input_type -> auth.LogoutRequest - 1, // 4: auth.AuthService.Login:output_type -> auth.LoginResponse - 3, // 5: auth.AuthService.Refresh:output_type -> auth.RefreshResponse - 5, // 6: auth.AuthService.Validate:output_type -> auth.ValidateResponse - 7, // 7: auth.AuthService.Logout:output_type -> auth.LogoutResponse - 4, // [4:8] is the sub-list for method output_type - 0, // [0:4] is the sub-list for method input_type + 8, // 0: auth.AuthService.Register:input_type -> auth.RegisterRequest + 0, // 1: auth.AuthService.Login:input_type -> auth.LoginRequest + 2, // 2: auth.AuthService.Refresh:input_type -> auth.RefreshRequest + 4, // 3: auth.AuthService.Validate:input_type -> auth.ValidateRequest + 6, // 4: auth.AuthService.Logout:input_type -> auth.LogoutRequest + 9, // 5: auth.AuthService.Register:output_type -> auth.RegisterResponse + 1, // 6: auth.AuthService.Login:output_type -> auth.LoginResponse + 3, // 7: auth.AuthService.Refresh:output_type -> auth.RefreshResponse + 5, // 8: auth.AuthService.Validate:output_type -> auth.ValidateResponse + 7, // 9: auth.AuthService.Logout:output_type -> auth.LogoutResponse + 5, // [5:10] is the sub-list for method output_type + 0, // [0:5] is the sub-list for method input_type 0, // [0:0] is the sub-list for extension type_name 0, // [0:0] is the sub-list for extension extendee 0, // [0:0] is the sub-list for field type_name @@ -524,7 +686,7 @@ func file_auth_auth_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_auth_auth_proto_rawDesc), len(file_auth_auth_proto_rawDesc)), NumEnums: 0, - NumMessages: 8, + NumMessages: 10, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/pb/auth/auth_grpc.pb.go b/pkg/pb/auth/auth_grpc.pb.go index 1b9c5a5..6a92b4a 100644 --- a/pkg/pb/auth/auth_grpc.pb.go +++ b/pkg/pb/auth/auth_grpc.pb.go @@ -19,6 +19,7 @@ import ( const _ = grpc.SupportPackageIsVersion9 const ( + AuthService_Register_FullMethodName = "/auth.AuthService/Register" AuthService_Login_FullMethodName = "/auth.AuthService/Login" AuthService_Refresh_FullMethodName = "/auth.AuthService/Refresh" AuthService_Validate_FullMethodName = "/auth.AuthService/Validate" @@ -29,6 +30,7 @@ const ( // // For semantics around ctx use and closing/ending streaming RPCs, please refer to https://pkg.go.dev/google.golang.org/grpc/?tab=doc#ClientConn.NewStream. type AuthServiceClient interface { + Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) Refresh(ctx context.Context, in *RefreshRequest, opts ...grpc.CallOption) (*RefreshResponse, error) Validate(ctx context.Context, in *ValidateRequest, opts ...grpc.CallOption) (*ValidateResponse, error) @@ -43,6 +45,16 @@ func NewAuthServiceClient(cc grpc.ClientConnInterface) AuthServiceClient { return &authServiceClient{cc} } +func (c *authServiceClient) Register(ctx context.Context, in *RegisterRequest, opts ...grpc.CallOption) (*RegisterResponse, error) { + cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) + out := new(RegisterResponse) + err := c.cc.Invoke(ctx, AuthService_Register_FullMethodName, in, out, cOpts...) + if err != nil { + return nil, err + } + return out, nil +} + func (c *authServiceClient) Login(ctx context.Context, in *LoginRequest, opts ...grpc.CallOption) (*LoginResponse, error) { cOpts := append([]grpc.CallOption{grpc.StaticMethod()}, opts...) out := new(LoginResponse) @@ -87,6 +99,7 @@ func (c *authServiceClient) Logout(ctx context.Context, in *LogoutRequest, opts // All implementations must embed UnimplementedAuthServiceServer // for forward compatibility. type AuthServiceServer interface { + Register(context.Context, *RegisterRequest) (*RegisterResponse, error) Login(context.Context, *LoginRequest) (*LoginResponse, error) Refresh(context.Context, *RefreshRequest) (*RefreshResponse, error) Validate(context.Context, *ValidateRequest) (*ValidateResponse, error) @@ -101,6 +114,9 @@ type AuthServiceServer interface { // pointer dereference when methods are called. type UnimplementedAuthServiceServer struct{} +func (UnimplementedAuthServiceServer) Register(context.Context, *RegisterRequest) (*RegisterResponse, error) { + return nil, status.Errorf(codes.Unimplemented, "method Register not implemented") +} func (UnimplementedAuthServiceServer) Login(context.Context, *LoginRequest) (*LoginResponse, error) { return nil, status.Errorf(codes.Unimplemented, "method Login not implemented") } @@ -134,6 +150,24 @@ func RegisterAuthServiceServer(s grpc.ServiceRegistrar, srv AuthServiceServer) { s.RegisterService(&AuthService_ServiceDesc, srv) } +func _AuthService_Register_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { + in := new(RegisterRequest) + if err := dec(in); err != nil { + return nil, err + } + if interceptor == nil { + return srv.(AuthServiceServer).Register(ctx, in) + } + info := &grpc.UnaryServerInfo{ + Server: srv, + FullMethod: AuthService_Register_FullMethodName, + } + handler := func(ctx context.Context, req interface{}) (interface{}, error) { + return srv.(AuthServiceServer).Register(ctx, req.(*RegisterRequest)) + } + return interceptor(ctx, in, info, handler) +} + func _AuthService_Login_Handler(srv interface{}, ctx context.Context, dec func(interface{}) error, interceptor grpc.UnaryServerInterceptor) (interface{}, error) { in := new(LoginRequest) if err := dec(in); err != nil { @@ -213,6 +247,10 @@ var AuthService_ServiceDesc = grpc.ServiceDesc{ ServiceName: "auth.AuthService", HandlerType: (*AuthServiceServer)(nil), Methods: []grpc.MethodDesc{ + { + MethodName: "Register", + Handler: _AuthService_Register_Handler, + }, { MethodName: "Login", Handler: _AuthService_Login_Handler, diff --git a/tests/auth_handler_test.go b/tests/auth_handler_test.go index 0005023..def4a1b 100644 --- a/tests/auth_handler_test.go +++ b/tests/auth_handler_test.go @@ -156,3 +156,159 @@ func (s *IntegrationSuite) TestAuthHandler_LogoutInvalidatesSession() { s.NotNil(validateResp) s.False(validateResp.Valid) } + +func (s *IntegrationSuite) TestAuthHandler_RegisterSuccess() { + ctx := context.Background() + + inviteCode := s.createActiveInviteCode(5) + + registerReq := &authpb.RegisterRequest{ + Email: "newuser@example.com", + Password: "newpassword123", + Name: "New User", + Phone: "+1234567890", + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + registerResp, err := s.authClient.Register(ctx, registerReq) + s.NoError(err) + s.NotNil(registerResp) + s.NotEmpty(registerResp.AccessToken) + s.NotEmpty(registerResp.RefreshToken) + + validateReq := &authpb.ValidateRequest{ + AccessToken: registerResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + s.NotNil(validateResp) + s.True(validateResp.Valid) + s.Greater(validateResp.UserId, int64(0)) +} + +func (s *IntegrationSuite) TestAuthHandler_RegisterInvalidInviteCode() { + ctx := context.Background() + + registerReq := &authpb.RegisterRequest{ + Email: "newuser2@example.com", + Password: "newpassword123", + Name: "New User 2", + Phone: "+1234567891", + InviteCode: 999999, + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + registerResp, err := s.authClient.Register(ctx, registerReq) + s.Error(err) + s.Nil(registerResp) + + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.FailedPrecondition, st.Code()) +} + +func (s *IntegrationSuite) TestAuthHandler_RegisterExpiredInviteCode() { + ctx := context.Background() + + inviteCode := s.createExpiredInviteCode() + + registerReq := &authpb.RegisterRequest{ + Email: "newuser3@example.com", + Password: "newpassword123", + Name: "New User 3", + Phone: "+1234567892", + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + registerResp, err := s.authClient.Register(ctx, registerReq) + s.Error(err) + s.Nil(registerResp) + + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.FailedPrecondition, st.Code()) +} + +func (s *IntegrationSuite) TestAuthHandler_RegisterExhaustedInviteCode() { + ctx := context.Background() + + inviteCode := s.createActiveInviteCode(1) + + registerReq1 := &authpb.RegisterRequest{ + Email: "newuser4@example.com", + Password: "newpassword123", + Name: "New User 4", + Phone: "+1234567893", + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + registerResp1, err := s.authClient.Register(ctx, registerReq1) + s.NoError(err) + s.NotNil(registerResp1) + + registerReq2 := &authpb.RegisterRequest{ + Email: "newuser5@example.com", + Password: "newpassword123", + Name: "New User 5", + Phone: "+1234567894", + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + registerResp2, err := s.authClient.Register(ctx, registerReq2) + s.Error(err) + s.Nil(registerResp2) + + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.FailedPrecondition, st.Code()) +} + +func (s *IntegrationSuite) TestAuthHandler_RegisterDuplicateEmail() { + ctx := context.Background() + + inviteCode := s.createActiveInviteCode(5) + + registerReq1 := &authpb.RegisterRequest{ + Email: "duplicate@example.com", + Password: "newpassword123", + Name: "Duplicate User", + Phone: "+1234567895", + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + registerResp1, err := s.authClient.Register(ctx, registerReq1) + s.NoError(err) + s.NotNil(registerResp1) + + inviteCode2 := s.createActiveInviteCode(5) + + registerReq2 := &authpb.RegisterRequest{ + Email: "duplicate@example.com", + Password: "anotherpassword", + Name: "Another User", + Phone: "+1234567896", + InviteCode: inviteCode2, + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + registerResp2, err := s.authClient.Register(ctx, registerReq2) + s.Error(err) + s.Nil(registerResp2) + + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.AlreadyExists, st.Code()) +} diff --git a/tests/integration_suite_test.go b/tests/integration_suite_test.go index 1ab5ba5..ae3d204 100644 --- a/tests/integration_suite_test.go +++ b/tests/integration_suite_test.go @@ -78,7 +78,7 @@ func (s *IntegrationSuite) SetupSuite() { s.T().Logf("PostgreSQL connection string: %s", connStr) s.T().Log("Running migrations...") - err = database.RunMigrationsFromPath(connStr, "../../../migrations") + err = database.RunMigrationsFromPath(connStr, "../migrations") s.Require().NoError(err) s.T().Log("Creating connection pool...") @@ -176,6 +176,30 @@ func (s *IntegrationSuite) createTestUser(email, password string) { s.Require().NoError(err) } +func (s *IntegrationSuite) createActiveInviteCode(canBeUsedCount int) int64 { + var inviteCode int64 + query := ` + INSERT INTO invite_codes (user_id, code, can_be_used_count, used_count, is_active, expires_at) + VALUES (1, FLOOR(100000 + RANDOM() * 900000)::bigint, $1, 0, true, NOW() + INTERVAL '30 days') + RETURNING code + ` + err := s.pool.QueryRow(s.ctx, query, canBeUsedCount).Scan(&inviteCode) + s.Require().NoError(err) + return inviteCode +} + +func (s *IntegrationSuite) createExpiredInviteCode() int64 { + var inviteCode int64 + query := ` + INSERT INTO invite_codes (user_id, code, can_be_used_count, used_count, is_active, expires_at) + VALUES (1, FLOOR(100000 + RANDOM() * 900000)::bigint, 5, 0, true, NOW() - INTERVAL '1 day') + RETURNING code + ` + err := s.pool.QueryRow(s.ctx, query).Scan(&inviteCode) + s.Require().NoError(err) + return inviteCode +} + func (s *IntegrationSuite) TearDownSuite() { s.T().Log("Tearing down integration suite...")