From d3d004569e7bdcc6122f53dc6db7f8f6e338f89d Mon Sep 17 00:00:00 2001 From: vallyenfail Date: Mon, 19 Jan 2026 23:50:42 +0300 Subject: [PATCH] add service --- internal/repository/invite.go | 13 +- pkg/jwt/jwt.go | 7 +- tests/concurrent_ownership_test.go | 243 ++++++++++++++++++++++++++ tests/concurrent_registration_test.go | 178 +++++++++++++++++++ tests/concurrent_request_test.go | 219 +++++++++++++++++++++++ tests/idempotency_test.go | 155 ++++++++++++++++ tests/integration_suite_test.go | 101 +++++++++++ tests/transaction_rollback_test.go | 194 ++++++++++++++++++++ tests/worker_concurrent_test.go | 181 +++++++++++++++++++ 9 files changed, 1285 insertions(+), 6 deletions(-) create mode 100644 tests/concurrent_ownership_test.go create mode 100644 tests/concurrent_registration_test.go create mode 100644 tests/concurrent_request_test.go create mode 100644 tests/idempotency_test.go create mode 100644 tests/transaction_rollback_test.go create mode 100644 tests/worker_concurrent_test.go diff --git a/internal/repository/invite.go b/internal/repository/invite.go index 06ce28a..850a7a7 100644 --- a/internal/repository/invite.go +++ b/internal/repository/invite.go @@ -146,18 +146,27 @@ func (r *inviteRepository) DecrementCanBeUsedCountTx(ctx context.Context, tx pgx query := r.qb.Update("invite_codes"). Set("can_be_used_count", sq.Expr("can_be_used_count - 1")). Set("is_active", sq.Expr("CASE WHEN can_be_used_count - 1 <= 0 THEN false ELSE is_active END")). - Where(sq.Eq{"code": code}) + Where(sq.And{ + sq.Eq{"code": code}, + sq.Expr("can_be_used_count > 0"), + sq.Eq{"is_active": true}, + sq.Expr("expires_at > now()"), + }) sqlQuery, args, err := query.ToSql() if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } - _, err = tx.Exec(ctx, sqlQuery, args...) + result, err := tx.Exec(ctx, sqlQuery, args...) if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to decrement can_be_used_count", err) } + if result.RowsAffected() == 0 { + return errs.NewBusinessError(errs.InviteInvalidOrExpired, "invite code is invalid, expired, or exhausted") + } + return nil } diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index d786523..f693009 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -11,7 +11,6 @@ import ( ) type Claims struct { - Sub string `json:"sub"` Type string `json:"type"` jwt.RegisteredClaims } @@ -19,9 +18,9 @@ type Claims struct { func GenerateAccessToken(userID int, secret string) (string, error) { now := time.Now() claims := Claims{ - Sub: strconv.Itoa(userID), Type: "access", RegisteredClaims: jwt.RegisteredClaims{ + Subject: strconv.Itoa(userID), ID: uuid.New().String(), IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)), @@ -35,9 +34,9 @@ func GenerateAccessToken(userID int, secret string) (string, error) { func GenerateRefreshToken(userID int, secret string) (string, error) { now := time.Now() claims := Claims{ - Sub: strconv.Itoa(userID), Type: "refresh", RegisteredClaims: jwt.RegisteredClaims{ + Subject: strconv.Itoa(userID), ID: uuid.New().String(), IssuedAt: jwt.NewNumericDate(now), ExpiresAt: jwt.NewNumericDate(now.Add(30 * 24 * time.Hour)), @@ -73,7 +72,7 @@ func GetUserIDFromToken(tokenString, secret string) (int, error) { return 0, err } - userID, err := strconv.Atoi(claims.Sub) + userID, err := strconv.Atoi(claims.Subject) if err != nil { return 0, fmt.Errorf("invalid user ID in token: %w", err) } diff --git a/tests/concurrent_ownership_test.go b/tests/concurrent_ownership_test.go new file mode 100644 index 0000000..5c21d1d --- /dev/null +++ b/tests/concurrent_ownership_test.go @@ -0,0 +1,243 @@ +package tests + +import ( + "sync" + "sync/atomic" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func (s *IntegrationSuite) TestConcurrentOwnership_User2TriesApproveTZ_WhileUser1Creates() { + email1, password1, _ := s.createUniqueTestUser("owner1", 1000.0) + email2, password2, _ := s.createUniqueTestUser("attacker1", 1000.0) + + login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email1, + Password: password1, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: login1.AccessToken, + }) + s.Require().NoError(err) + user1ID := validate1.UserId + + login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email2, + Password: password2, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validate2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: login2.AccessToken, + }) + s.Require().NoError(err) + user2ID := validate2.UserId + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: user1ID, + RequestTxt: "Request от User1 для теста ownership", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + var wg sync.WaitGroup + var user1Success, user2Denied int32 + + startBarrier := make(chan struct{}) + + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "User1 approves", + UserId: user1ID, + }) + if err == nil { + atomic.AddInt32(&user1Success, 1) + } + }() + + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "User2 tries to approve", + UserId: user2ID, + }) + if err != nil { + st, ok := status.FromError(err) + if ok && st.Code() == codes.PermissionDenied { + atomic.AddInt32(&user2Denied, 1) + } + } + }() + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("User1 success: %d, User2 denied: %d", user1Success, user2Denied) + + s.Equal(int32(5), user2Denied, + "Все попытки User2 должны быть отклонены с PermissionDenied") +} + +func (s *IntegrationSuite) TestConcurrentOwnership_ConcurrentApproveTZ_SameRequest() { + email, password, _ := s.createUniqueTestUser("concurrent_approve", 1000.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + userID := validateResp.UserId + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Request для concurrent ApproveTZ", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + var wg sync.WaitGroup + var successCount int32 + goroutines := 5 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Concurrent approve attempt", + UserId: userID, + }) + if err == nil { + atomic.AddInt32(&successCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Concurrent ApproveTZ success count: %d", successCount) + + suppliersCount := s.getRequestSuppliersCount(requestID) + s.T().Logf("Total suppliers for request: %d", suppliersCount) +} + +func (s *IntegrationSuite) TestConcurrentOwnership_SessionIsolation_AfterLogout() { + email1, password1, _ := s.createUniqueTestUser("session_iso1", 1000.0) + email2, password2, _ := s.createUniqueTestUser("session_iso2", 1000.0) + + login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email1, + Password: password1, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + user1Token := login1.AccessToken + + validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user1Token, + }) + s.Require().NoError(err) + s.True(validate1.Valid) + + login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email2, + Password: password2, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + user2Token := login2.AccessToken + + _, err = s.authClient.Logout(s.ctx, &authpb.LogoutRequest{ + AccessToken: user1Token, + }) + s.Require().NoError(err) + + validate1After, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user1Token, + }) + s.NoError(err) + s.False(validate1After.Valid, + "Токен User1 должен быть невалиден после logout") + + validate2After, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user2Token, + }) + s.NoError(err) + s.True(validate2After.Valid, + "Токен User2 должен оставаться валидным после logout User1") + + var wg sync.WaitGroup + var user1Invalid, user2Valid int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + if idx%2 == 0 { + resp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user1Token, + }) + if err == nil && !resp.Valid { + atomic.AddInt32(&user1Invalid, 1) + } + } else { + resp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user2Token, + }) + if err == nil && resp.Valid { + atomic.AddInt32(&user2Valid, 1) + } + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Session isolation - User1 invalid: %d, User2 valid: %d", user1Invalid, user2Valid) + + s.Equal(int32(goroutines/2), user1Invalid, + "Все проверки токена User1 должны показать invalid") + s.Equal(int32(goroutines/2), user2Valid, + "Все проверки токена User2 должны показать valid") +} diff --git a/tests/concurrent_registration_test.go b/tests/concurrent_registration_test.go new file mode 100644 index 0000000..1554d79 --- /dev/null +++ b/tests/concurrent_registration_test.go @@ -0,0 +1,178 @@ +package tests + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" +) + +func (s *IntegrationSuite) TestConcurrent_Registration_WithSingleInviteCode() { + maxUses := 3 + inviteCode := s.createActiveInviteCode(maxUses) + + var wg sync.WaitGroup + var successCount int32 + var errorCount int32 + goroutines := 20 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + email := fmt.Sprintf("concurrent_reg_%d_%d@example.com", idx, time.Now().UnixNano()) + + _, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email, + Password: "testpassword123", + Name: fmt.Sprintf("User %d", idx), + Phone: fmt.Sprintf("+1%010d", idx), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } else { + atomic.AddInt32(&errorCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Registration results - Success: %d, Errors: %d", successCount, errorCount) + + s.LessOrEqual(int(successCount), maxUses, + "Количество успешных регистраций (%d) не должно превышать лимит invite-кода (%d)", successCount, maxUses) + + remainingUses := s.getInviteCodeUsageCount(inviteCode) + s.T().Logf("Remaining invite code uses: %d", remainingUses) + + s.Equal(maxUses-int(successCount), remainingUses, + "Оставшееся количество использований должно соответствовать успешным регистрациям") +} + +func (s *IntegrationSuite) TestConcurrent_Registration_InviteCodeDeactivation() { + maxUses := 2 + inviteCode := s.createActiveInviteCode(maxUses) + + s.True(s.isInviteCodeActive(inviteCode), "Invite code должен быть активен изначально") + + var wg sync.WaitGroup + var successCount int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + email := fmt.Sprintf("deactivation_test_%d_%d@example.com", idx, time.Now().UnixNano()) + + _, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email, + Password: "testpassword123", + Name: fmt.Sprintf("User %d", idx), + Phone: fmt.Sprintf("+2%010d", idx), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Registration success count: %d", successCount) + + s.LessOrEqual(int(successCount), maxUses, + "Не должно быть больше %d успешных регистраций", maxUses) + + remainingUses := s.getInviteCodeUsageCount(inviteCode) + s.GreaterOrEqual(remainingUses, 0, + "Количество использований не должно быть отрицательным") +} + +func (s *IntegrationSuite) TestConcurrent_Registration_MultipleInviteCodes() { + inviteCode1 := s.createActiveInviteCode(2) + inviteCode2 := s.createActiveInviteCode(2) + + var wg sync.WaitGroup + var success1, success2 int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + var code int64 + if idx%2 == 0 { + code = inviteCode1 + } else { + code = inviteCode2 + } + + email := fmt.Sprintf("multi_invite_%d_%d@example.com", idx, time.Now().UnixNano()) + + _, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email, + Password: "testpassword123", + Name: fmt.Sprintf("User %d", idx), + Phone: fmt.Sprintf("+3%010d", idx), + InviteCode: code, + Ip: "127.0.0.1", + UserAgent: "integration-test", + }) + + if err == nil { + if code == inviteCode1 { + atomic.AddInt32(&success1, 1) + } else { + atomic.AddInt32(&success2, 1) + } + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Multi-invite results - Code1: %d, Code2: %d", success1, success2) + + s.LessOrEqual(int(success1), 2, + "Invite code 1 не должен превышать лимит") + s.LessOrEqual(int(success2), 2, + "Invite code 2 не должен превышать лимит") + + remaining1 := s.getInviteCodeUsageCount(inviteCode1) + remaining2 := s.getInviteCodeUsageCount(inviteCode2) + + s.Equal(2-int(success1), remaining1, + "Остаток invite code 1 должен соответствовать успешным регистрациям") + s.Equal(2-int(success2), remaining2, + "Остаток invite code 2 должен соответствовать успешным регистрациям") +} diff --git a/tests/concurrent_request_test.go b/tests/concurrent_request_test.go new file mode 100644 index 0000000..78a5fe5 --- /dev/null +++ b/tests/concurrent_request_test.go @@ -0,0 +1,219 @@ +package tests + +import ( + "sync" + "sync/atomic" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" +) + +func (s *IntegrationSuite) TestConcurrentRequest_CreateTZ_LimitedBalance() { + initialBalance := 50.0 + email, password, userID := s.createUniqueTestUser("limited_balance_tz", initialBalance) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + var wg sync.WaitGroup + var successCount int32 + var errorCount int32 + goroutines := 20 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Параллельный CreateTZ с ограниченным балансом", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } else { + atomic.AddInt32(&errorCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("CreateTZ with limited balance - Success: %d, Errors: %d", successCount, errorCount) + + finalBalance := s.getUserBalance(userID) + s.T().Logf("Final balance: %.4f (initial: %.4f)", finalBalance, initialBalance) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным") + + var requestsWithTZ int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM requests_for_suppliers WHERE user_id = $1 AND generated_tz = true", + userID, + ).Scan(&requestsWithTZ) + s.NoError(err) + + s.T().Logf("Requests with generated TZ: %d", requestsWithTZ) + + s.GreaterOrEqual(requestsWithTZ, 0, + "Количество запросов с TZ должно быть >= 0") + + s.LessOrEqual(requestsWithTZ, int(successCount), + "Количество запросов с TZ не должно превышать успешные операции") +} + +func (s *IntegrationSuite) TestConcurrentRequest_MultipleUsers_CreateTZ() { + user1Email, user1Pass, user1ID := s.createUniqueTestUser("multi_user1", 500.0) + user2Email, user2Pass, user2ID := s.createUniqueTestUser("multi_user2", 500.0) + user3Email, user3Pass, user3ID := s.createUniqueTestUser("multi_user3", 500.0) + + login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: user1Email, Password: user1Pass, Ip: "127.0.0.1", UserAgent: "test", + }) + s.Require().NoError(err) + validate1, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login1.AccessToken}) + + login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: user2Email, Password: user2Pass, Ip: "127.0.0.1", UserAgent: "test", + }) + s.Require().NoError(err) + validate2, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login2.AccessToken}) + + login3, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: user3Email, Password: user3Pass, Ip: "127.0.0.1", UserAgent: "test", + }) + s.Require().NoError(err) + validate3, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login3.AccessToken}) + + users := []struct { + userID int64 + id int + }{ + {validate1.UserId, user1ID}, + {validate2.UserId, user2ID}, + {validate3.UserId, user3ID}, + } + + var wg sync.WaitGroup + var totalSuccess int32 + requestsPerUser := 5 + + startBarrier := make(chan struct{}) + + for _, user := range users { + for i := 0; i < requestsPerUser; i++ { + wg.Add(1) + go func(uid int64) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: uid, + RequestTxt: "Multi-user concurrent CreateTZ", + }) + + if err == nil { + atomic.AddInt32(&totalSuccess, 1) + } + }(user.userID) + } + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Multi-user CreateTZ total success: %d", totalSuccess) + + for _, user := range users { + balance := s.getUserBalance(user.id) + s.T().Logf("User %d final balance: %.4f", user.id, balance) + s.GreaterOrEqual(balance, 0.0, + "Баланс пользователя %d не должен быть отрицательным", user.id) + } +} + +func (s *IntegrationSuite) TestConcurrentRequest_BalanceDeduction_Consistency() { + initialBalance := 1000.0 + email, password, userID := s.createUniqueTestUser("balance_consistency", initialBalance) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + var wg sync.WaitGroup + var successCount int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Balance consistency test", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Successful CreateTZ operations: %d", successCount) + + finalBalance := s.getUserBalance(userID) + balanceSpent := initialBalance - finalBalance + s.T().Logf("Balance spent: %.4f", balanceSpent) + + var totalTokenCost float64 + err = s.pool.QueryRow(s.ctx, ` + SELECT COALESCE(SUM(tu.token_cost), 0) + FROM request_token_usage tu + JOIN requests_for_suppliers r ON tu.request_id = r.id + WHERE r.user_id = $1 + `, userID).Scan(&totalTokenCost) + s.NoError(err) + + s.T().Logf("Total token cost from DB: %.4f", totalTokenCost) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным") + + if totalTokenCost > 0 { + tolerance := 0.01 + s.InDelta(totalTokenCost, balanceSpent, tolerance, + "Сумма token_cost должна соответствовать списанному балансу") + } +} diff --git a/tests/idempotency_test.go b/tests/idempotency_test.go new file mode 100644 index 0000000..a417ebc --- /dev/null +++ b/tests/idempotency_test.go @@ -0,0 +1,155 @@ +package tests + +import ( + "fmt" + "time" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" +) + +func (s *IntegrationSuite) TestIdempotency_DoubleCreateTZ_CreatesTwoRequests() { + email, password, userID := s.createUniqueTestUser("idempotency_tz", 1000.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + requestText := "Одинаковый текст запроса для теста идемпотентности" + + resp1, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: requestText, + }) + s.Require().NoError(err) + requestID1 := resp1.RequestId + + resp2, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: requestText, + }) + s.Require().NoError(err) + requestID2 := resp2.RequestId + + s.T().Logf("Request 1 ID: %s", requestID1) + s.T().Logf("Request 2 ID: %s", requestID2) + + s.NotEqual(requestID1, requestID2, + "Два вызова CreateTZ должны создать два разных request") + + var requestCount int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM requests_for_suppliers WHERE user_id = $1 AND request_txt = $2", + userID, requestText, + ).Scan(&requestCount) + s.NoError(err) + + s.Equal(2, requestCount, + "Должно быть создано 2 запроса с одинаковым текстом") +} + +func (s *IntegrationSuite) TestIdempotency_DoubleRegister_SameInviteCode() { + inviteCode := s.createActiveInviteCode(5) + + email1 := fmt.Sprintf("double_reg1_%d@example.com", time.Now().UnixNano()) + + resp1, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email1, + Password: "testpassword", + Name: "User 1", + Phone: fmt.Sprintf("+1%010d", time.Now().UnixNano()%10000000000), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + s.NotEmpty(resp1.AccessToken) + + email2 := fmt.Sprintf("double_reg2_%d@example.com", time.Now().UnixNano()) + + resp2, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email2, + Password: "testpassword", + Name: "User 2", + Phone: fmt.Sprintf("+2%010d", time.Now().UnixNano()%10000000000), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + s.NotEmpty(resp2.AccessToken) + + remainingUses := s.getInviteCodeUsageCount(inviteCode) + s.T().Logf("Remaining invite uses: %d", remainingUses) + + s.Equal(3, remainingUses, + "После двух регистраций должно остаться 3 использования (5-2)") + + validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: resp1.AccessToken, + }) + s.NoError(err) + + validate2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: resp2.AccessToken, + }) + s.NoError(err) + + s.NotEqual(validate1.UserId, validate2.UserId, + "Должны быть созданы два разных пользователя") +} + +func (s *IntegrationSuite) TestIdempotency_DoubleLogout_SameToken() { + email, password, _ := s.createUniqueTestUser("double_logout", 100.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + accessToken := loginResp.AccessToken + + validateBefore, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.True(validateBefore.Valid) + + logout1, err := s.authClient.Logout(s.ctx, &authpb.LogoutRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.True(logout1.Success) + + validateAfter1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.False(validateAfter1.Valid, + "Токен должен быть невалиден после первого logout") + + logout2, err := s.authClient.Logout(s.ctx, &authpb.LogoutRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.True(logout2.Success, + "Повторный logout должен быть успешным (идемпотентность)") + + validateAfter2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.False(validateAfter2.Valid, + "Токен должен оставаться невалидным после повторного logout") +} diff --git a/tests/integration_suite_test.go b/tests/integration_suite_test.go index 9f3dee1..b1f745f 100644 --- a/tests/integration_suite_test.go +++ b/tests/integration_suite_test.go @@ -275,3 +275,104 @@ func (s *IntegrationSuite) createSecondTestUser() (email string, password string return email, password, userID } + +func (s *IntegrationSuite) getInviteCodeUsageCount(code int64) int { + var count int + err := s.pool.QueryRow(s.ctx, + "SELECT can_be_used_count FROM invite_codes WHERE code = $1", + code, + ).Scan(&count) + if err != nil { + return -1 + } + return count +} + +func (s *IntegrationSuite) getRequestSuppliersCount(requestID string) int { + var count int + err := s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM suppliers WHERE request_id = $1::uuid", + requestID, + ).Scan(&count) + if err != nil { + return -1 + } + return count +} + +func (s *IntegrationSuite) getUserBalance(userID int) float64 { + var balance float64 + err := s.pool.QueryRow(s.ctx, + "SELECT balance FROM users WHERE id = $1", + userID, + ).Scan(&balance) + if err != nil { + return -1 + } + return balance +} + +func (s *IntegrationSuite) getTokenUsageCount(requestID string) int { + var count int + err := s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM request_token_usage WHERE request_id = $1::uuid", + requestID, + ).Scan(&count) + if err != nil { + return -1 + } + return count +} + +func (s *IntegrationSuite) createUniqueTestUser(suffix string, balance float64) (email string, password string, userID int) { + email = fmt.Sprintf("user_%s_%d@example.com", suffix, time.Now().UnixNano()) + password = "testpassword" + + cryptoHelper := crypto.NewCrypto(testCryptoSecret) + + encryptedEmail, err := cryptoHelper.Encrypt(email) + s.Require().NoError(err) + + encryptedPhone, err := cryptoHelper.Encrypt(fmt.Sprintf("+1%d", time.Now().UnixNano()%10000000000)) + s.Require().NoError(err) + + encryptedUserName, err := cryptoHelper.Encrypt(fmt.Sprintf("User %s", suffix)) + s.Require().NoError(err) + + emailHash := cryptoHelper.EmailHash(email) + passwordHash := crypto.PasswordHash(password) + + query := ` + INSERT INTO users (email, email_hash, password_hash, phone, user_name, company_name, balance, payment_status, invites_issued, invites_limit) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING id + ` + + err = s.pool.QueryRow(s.ctx, query, + encryptedEmail, + emailHash, + passwordHash, + encryptedPhone, + encryptedUserName, + "Test Company", + balance, + "active", + 0, + 10, + ).Scan(&userID) + s.Require().NoError(err) + + return email, password, userID +} + +func (s *IntegrationSuite) isInviteCodeActive(code int64) bool { + var isActive bool + err := s.pool.QueryRow(s.ctx, + "SELECT is_active FROM invite_codes WHERE code = $1", + code, + ).Scan(&isActive) + if err != nil { + return false + } + return isActive +} diff --git a/tests/transaction_rollback_test.go b/tests/transaction_rollback_test.go new file mode 100644 index 0000000..cde240a --- /dev/null +++ b/tests/transaction_rollback_test.go @@ -0,0 +1,194 @@ +package tests + +import ( + "sync" + "sync/atomic" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" +) + +func (s *IntegrationSuite) TestTransaction_CreateTZ_InsufficientBalance_Rollback() { + email, password, userID := s.createUniqueTestUser("insufficient_tz", 0.001) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + initialBalance := s.getUserBalance(userID) + s.T().Logf("Initial balance: %.4f", initialBalance) + + _, err = s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Тест с недостаточным балансом", + }) + + if err != nil { + s.T().Logf("CreateTZ failed as expected: %v", err) + + finalBalance := s.getUserBalance(userID) + s.T().Logf("Final balance after failed CreateTZ: %.4f", finalBalance) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным после rollback") + } +} + +func (s *IntegrationSuite) TestTransaction_ApproveTZ_InsufficientBalance_NoSuppliers() { + email, password, userID := s.createUniqueTestUser("approve_insufficient", 1000.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Тест approve с недостаточным балансом", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + _, err = s.pool.Exec(s.ctx, "UPDATE users SET balance = 0.001 WHERE id = $1", userID) + s.Require().NoError(err) + + suppliersBeforeApprove := s.getRequestSuppliersCount(requestID) + s.T().Logf("Suppliers before ApproveTZ: %d", suppliersBeforeApprove) + + _, err = s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Утвержденное ТЗ", + UserId: validateResp.UserId, + }) + + if err != nil { + s.T().Logf("ApproveTZ failed as expected: %v", err) + + suppliersAfterApprove := s.getRequestSuppliersCount(requestID) + s.T().Logf("Suppliers after failed ApproveTZ: %d", suppliersAfterApprove) + + finalBalance := s.getUserBalance(userID) + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным") + } +} + +func (s *IntegrationSuite) TestTransaction_ConcurrentCreateTZ_BalanceAtomicity() { + email, password, userID := s.createUniqueTestUser("concurrent_tz", 100.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + var wg sync.WaitGroup + var successCount int32 + var errorCount int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Параллельный тест CreateTZ", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } else { + atomic.AddInt32(&errorCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Concurrent CreateTZ - Success: %d, Errors: %d", successCount, errorCount) + + finalBalance := s.getUserBalance(userID) + s.T().Logf("Final balance: %.4f", finalBalance) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным после параллельных операций") +} + +func (s *IntegrationSuite) TestTransaction_TokenUsage_BalanceConsistency() { + email, password, userID := s.createUniqueTestUser("token_consistency", 1000.0) + + initialBalance := s.getUserBalance(userID) + s.T().Logf("Initial balance: %.4f", initialBalance) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Тест consistency token_usage и balance", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + tokenUsageCount := s.getTokenUsageCount(requestID) + s.T().Logf("Token usage records for request: %d", tokenUsageCount) + + finalBalance := s.getUserBalance(userID) + balanceDelta := initialBalance - finalBalance + s.T().Logf("Balance delta: %.4f", balanceDelta) + + if tokenUsageCount > 0 { + s.Greater(balanceDelta, 0.0, + "Баланс должен уменьшиться при наличии token_usage записей") + } + + var totalTokenCost float64 + err = s.pool.QueryRow(s.ctx, + "SELECT COALESCE(SUM(token_cost), 0) FROM request_token_usage WHERE request_id = $1::uuid", + requestID, + ).Scan(&totalTokenCost) + s.NoError(err) + + s.T().Logf("Total token cost from DB: %.4f, Balance delta: %.4f", totalTokenCost, balanceDelta) +} diff --git a/tests/worker_concurrent_test.go b/tests/worker_concurrent_test.go new file mode 100644 index 0000000..53e410a --- /dev/null +++ b/tests/worker_concurrent_test.go @@ -0,0 +1,181 @@ +package tests + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" +) + +func (s *IntegrationSuite) TestWorkerConcurrent_SessionCleanup_MassExpired() { + _, _, userID := s.createUniqueTestUser("session_cleanup", 100.0) + + expiredTime := time.Now().Add(-24 * time.Hour) + validTime := time.Now().Add(24 * time.Hour) + + expiredCount := 100 + validCount := 50 + + for i := 0; i < expiredCount; i++ { + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO sessions (user_id, access_token, refresh_token, ip, user_agent, expires_at) + VALUES ($1, $2, $3, '127.0.0.1', 'test-agent', $4) + `, userID, + fmt.Sprintf("expired_access_%d_%s", i, uuid.New().String()), + fmt.Sprintf("expired_refresh_%d_%s", i, uuid.New().String()), + expiredTime, + ) + s.Require().NoError(err) + } + + for i := 0; i < validCount; i++ { + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO sessions (user_id, access_token, refresh_token, ip, user_agent, expires_at) + VALUES ($1, $2, $3, '127.0.0.1', 'test-agent', $4) + `, userID, + fmt.Sprintf("valid_access_%d_%s", i, uuid.New().String()), + fmt.Sprintf("valid_refresh_%d_%s", i, uuid.New().String()), + validTime, + ) + s.Require().NoError(err) + } + + var totalBefore int + err := s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM sessions WHERE user_id = $1", userID, + ).Scan(&totalBefore) + s.Require().NoError(err) + s.T().Logf("Sessions before cleanup: %d", totalBefore) + + var wg sync.WaitGroup + var totalDeleted int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + result, err := s.pool.Exec(s.ctx, ` + DELETE FROM sessions + WHERE expires_at < now() + OR (revoked_at IS NOT NULL AND revoked_at < now() - interval '30 days') + `) + + if err == nil { + atomic.AddInt32(&totalDeleted, int32(result.RowsAffected())) + } + }() + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Total deleted by concurrent cleanup: %d", totalDeleted) + + var validRemaining int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM sessions WHERE user_id = $1 AND expires_at > now()", userID, + ).Scan(&validRemaining) + s.NoError(err) + + s.T().Logf("Valid sessions remaining: %d (expected: %d)", validRemaining, validCount) + + s.Equal(validCount, validRemaining, + "Все валидные сессии должны остаться после cleanup") + + s.GreaterOrEqual(int(totalDeleted), expiredCount, + "Все истекшие сессии должны быть удалены") +} + +func (s *IntegrationSuite) TestWorkerConcurrent_InviteCleanup_MassExpired() { + _, _, userID := s.createUniqueTestUser("invite_cleanup", 100.0) + + _, err := s.pool.Exec(s.ctx, "UPDATE users SET invites_limit = 200 WHERE id = $1", userID) + s.Require().NoError(err) + + expiredTime := time.Now().Add(-24 * time.Hour) + validTime := time.Now().Add(24 * time.Hour) + + expiredCount := 100 + validCount := 50 + + for i := 0; i < expiredCount; i++ { + code := int64(30000000 + i) + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO invite_codes (user_id, code, can_be_used_count, expires_at, is_active) + VALUES ($1, $2, 5, $3, true) + `, userID, code, expiredTime) + s.Require().NoError(err) + } + + for i := 0; i < validCount; i++ { + code := int64(40000000 + i) + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO invite_codes (user_id, code, can_be_used_count, expires_at, is_active) + VALUES ($1, $2, 5, $3, true) + `, userID, code, validTime) + s.Require().NoError(err) + } + + var activeBefore int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND is_active = true", userID, + ).Scan(&activeBefore) + s.Require().NoError(err) + s.T().Logf("Active invites before cleanup: %d", activeBefore) + + var wg sync.WaitGroup + var totalDeactivated int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + result, err := s.pool.Exec(s.ctx, ` + UPDATE invite_codes + SET is_active = false + WHERE expires_at < now() AND is_active = true + `) + + if err == nil { + atomic.AddInt32(&totalDeactivated, int32(result.RowsAffected())) + } + }() + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Total deactivated by concurrent cleanup: %d", totalDeactivated) + + var activeRemaining int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND is_active = true", userID, + ).Scan(&activeRemaining) + s.NoError(err) + + s.T().Logf("Active invites remaining: %d (expected: %d)", activeRemaining, validCount) + + s.Equal(validCount, activeRemaining, + "Все валидные инвайты должны остаться активными после cleanup") + + var expiredStillActive int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND expires_at < now() AND is_active = true", userID, + ).Scan(&expiredStillActive) + s.NoError(err) + + s.Equal(0, expiredStillActive, + "Не должно остаться активных истекших инвайтов") +}