This commit is contained in:
@@ -146,18 +146,27 @@ func (r *inviteRepository) DecrementCanBeUsedCountTx(ctx context.Context, tx pgx
|
||||
query := r.qb.Update("invite_codes").
|
||||
Set("can_be_used_count", sq.Expr("can_be_used_count - 1")).
|
||||
Set("is_active", sq.Expr("CASE WHEN can_be_used_count - 1 <= 0 THEN false ELSE is_active END")).
|
||||
Where(sq.Eq{"code": code})
|
||||
Where(sq.And{
|
||||
sq.Eq{"code": code},
|
||||
sq.Expr("can_be_used_count > 0"),
|
||||
sq.Eq{"is_active": true},
|
||||
sq.Expr("expires_at > now()"),
|
||||
})
|
||||
|
||||
sqlQuery, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec(ctx, sqlQuery, args...)
|
||||
result, err := tx.Exec(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to decrement can_be_used_count", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return errs.NewBusinessError(errs.InviteInvalidOrExpired, "invite code is invalid, expired, or exhausted")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
|
||||
@@ -11,7 +11,6 @@ import (
|
||||
)
|
||||
|
||||
type Claims struct {
|
||||
Sub string `json:"sub"`
|
||||
Type string `json:"type"`
|
||||
jwt.RegisteredClaims
|
||||
}
|
||||
@@ -19,9 +18,9 @@ type Claims struct {
|
||||
func GenerateAccessToken(userID int, secret string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
Sub: strconv.Itoa(userID),
|
||||
Type: "access",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: strconv.Itoa(userID),
|
||||
ID: uuid.New().String(),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)),
|
||||
@@ -35,9 +34,9 @@ func GenerateAccessToken(userID int, secret string) (string, error) {
|
||||
func GenerateRefreshToken(userID int, secret string) (string, error) {
|
||||
now := time.Now()
|
||||
claims := Claims{
|
||||
Sub: strconv.Itoa(userID),
|
||||
Type: "refresh",
|
||||
RegisteredClaims: jwt.RegisteredClaims{
|
||||
Subject: strconv.Itoa(userID),
|
||||
ID: uuid.New().String(),
|
||||
IssuedAt: jwt.NewNumericDate(now),
|
||||
ExpiresAt: jwt.NewNumericDate(now.Add(30 * 24 * time.Hour)),
|
||||
@@ -73,7 +72,7 @@ func GetUserIDFromToken(tokenString, secret string) (int, error) {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
userID, err := strconv.Atoi(claims.Sub)
|
||||
userID, err := strconv.Atoi(claims.Subject)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("invalid user ID in token: %w", err)
|
||||
}
|
||||
|
||||
243
tests/concurrent_ownership_test.go
Normal file
243
tests/concurrent_ownership_test.go
Normal file
@@ -0,0 +1,243 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth"
|
||||
requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrentOwnership_User2TriesApproveTZ_WhileUser1Creates() {
|
||||
email1, password1, _ := s.createUniqueTestUser("owner1", 1000.0)
|
||||
email2, password2, _ := s.createUniqueTestUser("attacker1", 1000.0)
|
||||
|
||||
login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email1,
|
||||
Password: password1,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: login1.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
user1ID := validate1.UserId
|
||||
|
||||
login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email2,
|
||||
Password: password2,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validate2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: login2.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
user2ID := validate2.UserId
|
||||
|
||||
createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: user1ID,
|
||||
RequestTxt: "Request от User1 для теста ownership",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
requestID := createResp.RequestId
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var user1Success, user2Denied int32
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
_, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{
|
||||
RequestId: requestID,
|
||||
FinalTz: "User1 approves",
|
||||
UserId: user1ID,
|
||||
})
|
||||
if err == nil {
|
||||
atomic.AddInt32(&user1Success, 1)
|
||||
}
|
||||
}()
|
||||
|
||||
for i := 0; i < 5; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
_, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{
|
||||
RequestId: requestID,
|
||||
FinalTz: "User2 tries to approve",
|
||||
UserId: user2ID,
|
||||
})
|
||||
if err != nil {
|
||||
st, ok := status.FromError(err)
|
||||
if ok && st.Code() == codes.PermissionDenied {
|
||||
atomic.AddInt32(&user2Denied, 1)
|
||||
}
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("User1 success: %d, User2 denied: %d", user1Success, user2Denied)
|
||||
|
||||
s.Equal(int32(5), user2Denied,
|
||||
"Все попытки User2 должны быть отклонены с PermissionDenied")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrentOwnership_ConcurrentApproveTZ_SameRequest() {
|
||||
email, password, _ := s.createUniqueTestUser("concurrent_approve", 1000.0)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
userID := validateResp.UserId
|
||||
|
||||
createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: userID,
|
||||
RequestTxt: "Request для concurrent ApproveTZ",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
requestID := createResp.RequestId
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int32
|
||||
goroutines := 5
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
_, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{
|
||||
RequestId: requestID,
|
||||
FinalTz: "Concurrent approve attempt",
|
||||
UserId: userID,
|
||||
})
|
||||
if err == nil {
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Concurrent ApproveTZ success count: %d", successCount)
|
||||
|
||||
suppliersCount := s.getRequestSuppliersCount(requestID)
|
||||
s.T().Logf("Total suppliers for request: %d", suppliersCount)
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrentOwnership_SessionIsolation_AfterLogout() {
|
||||
email1, password1, _ := s.createUniqueTestUser("session_iso1", 1000.0)
|
||||
email2, password2, _ := s.createUniqueTestUser("session_iso2", 1000.0)
|
||||
|
||||
login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email1,
|
||||
Password: password1,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
user1Token := login1.AccessToken
|
||||
|
||||
validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: user1Token,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.True(validate1.Valid)
|
||||
|
||||
login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email2,
|
||||
Password: password2,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
user2Token := login2.AccessToken
|
||||
|
||||
_, err = s.authClient.Logout(s.ctx, &authpb.LogoutRequest{
|
||||
AccessToken: user1Token,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validate1After, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: user1Token,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.False(validate1After.Valid,
|
||||
"Токен User1 должен быть невалиден после logout")
|
||||
|
||||
validate2After, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: user2Token,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(validate2After.Valid,
|
||||
"Токен User2 должен оставаться валидным после logout User1")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var user1Invalid, user2Valid int32
|
||||
goroutines := 10
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
if idx%2 == 0 {
|
||||
resp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: user1Token,
|
||||
})
|
||||
if err == nil && !resp.Valid {
|
||||
atomic.AddInt32(&user1Invalid, 1)
|
||||
}
|
||||
} else {
|
||||
resp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: user2Token,
|
||||
})
|
||||
if err == nil && resp.Valid {
|
||||
atomic.AddInt32(&user2Valid, 1)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Session isolation - User1 invalid: %d, User2 valid: %d", user1Invalid, user2Valid)
|
||||
|
||||
s.Equal(int32(goroutines/2), user1Invalid,
|
||||
"Все проверки токена User1 должны показать invalid")
|
||||
s.Equal(int32(goroutines/2), user2Valid,
|
||||
"Все проверки токена User2 должны показать valid")
|
||||
}
|
||||
178
tests/concurrent_registration_test.go
Normal file
178
tests/concurrent_registration_test.go
Normal file
@@ -0,0 +1,178 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth"
|
||||
)
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrent_Registration_WithSingleInviteCode() {
|
||||
maxUses := 3
|
||||
inviteCode := s.createActiveInviteCode(maxUses)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int32
|
||||
var errorCount int32
|
||||
goroutines := 20
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-startBarrier
|
||||
|
||||
email := fmt.Sprintf("concurrent_reg_%d_%d@example.com", idx, time.Now().UnixNano())
|
||||
|
||||
_, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{
|
||||
Email: email,
|
||||
Password: "testpassword123",
|
||||
Name: fmt.Sprintf("User %d", idx),
|
||||
Phone: fmt.Sprintf("+1%010d", idx),
|
||||
InviteCode: inviteCode,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "integration-test",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
} else {
|
||||
atomic.AddInt32(&errorCount, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Registration results - Success: %d, Errors: %d", successCount, errorCount)
|
||||
|
||||
s.LessOrEqual(int(successCount), maxUses,
|
||||
"Количество успешных регистраций (%d) не должно превышать лимит invite-кода (%d)", successCount, maxUses)
|
||||
|
||||
remainingUses := s.getInviteCodeUsageCount(inviteCode)
|
||||
s.T().Logf("Remaining invite code uses: %d", remainingUses)
|
||||
|
||||
s.Equal(maxUses-int(successCount), remainingUses,
|
||||
"Оставшееся количество использований должно соответствовать успешным регистрациям")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrent_Registration_InviteCodeDeactivation() {
|
||||
maxUses := 2
|
||||
inviteCode := s.createActiveInviteCode(maxUses)
|
||||
|
||||
s.True(s.isInviteCodeActive(inviteCode), "Invite code должен быть активен изначально")
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int32
|
||||
goroutines := 10
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-startBarrier
|
||||
|
||||
email := fmt.Sprintf("deactivation_test_%d_%d@example.com", idx, time.Now().UnixNano())
|
||||
|
||||
_, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{
|
||||
Email: email,
|
||||
Password: "testpassword123",
|
||||
Name: fmt.Sprintf("User %d", idx),
|
||||
Phone: fmt.Sprintf("+2%010d", idx),
|
||||
InviteCode: inviteCode,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "integration-test",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Registration success count: %d", successCount)
|
||||
|
||||
s.LessOrEqual(int(successCount), maxUses,
|
||||
"Не должно быть больше %d успешных регистраций", maxUses)
|
||||
|
||||
remainingUses := s.getInviteCodeUsageCount(inviteCode)
|
||||
s.GreaterOrEqual(remainingUses, 0,
|
||||
"Количество использований не должно быть отрицательным")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrent_Registration_MultipleInviteCodes() {
|
||||
inviteCode1 := s.createActiveInviteCode(2)
|
||||
inviteCode2 := s.createActiveInviteCode(2)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var success1, success2 int32
|
||||
goroutines := 10
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-startBarrier
|
||||
|
||||
var code int64
|
||||
if idx%2 == 0 {
|
||||
code = inviteCode1
|
||||
} else {
|
||||
code = inviteCode2
|
||||
}
|
||||
|
||||
email := fmt.Sprintf("multi_invite_%d_%d@example.com", idx, time.Now().UnixNano())
|
||||
|
||||
_, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{
|
||||
Email: email,
|
||||
Password: "testpassword123",
|
||||
Name: fmt.Sprintf("User %d", idx),
|
||||
Phone: fmt.Sprintf("+3%010d", idx),
|
||||
InviteCode: code,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "integration-test",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
if code == inviteCode1 {
|
||||
atomic.AddInt32(&success1, 1)
|
||||
} else {
|
||||
atomic.AddInt32(&success2, 1)
|
||||
}
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Multi-invite results - Code1: %d, Code2: %d", success1, success2)
|
||||
|
||||
s.LessOrEqual(int(success1), 2,
|
||||
"Invite code 1 не должен превышать лимит")
|
||||
s.LessOrEqual(int(success2), 2,
|
||||
"Invite code 2 не должен превышать лимит")
|
||||
|
||||
remaining1 := s.getInviteCodeUsageCount(inviteCode1)
|
||||
remaining2 := s.getInviteCodeUsageCount(inviteCode2)
|
||||
|
||||
s.Equal(2-int(success1), remaining1,
|
||||
"Остаток invite code 1 должен соответствовать успешным регистрациям")
|
||||
s.Equal(2-int(success2), remaining2,
|
||||
"Остаток invite code 2 должен соответствовать успешным регистрациям")
|
||||
}
|
||||
219
tests/concurrent_request_test.go
Normal file
219
tests/concurrent_request_test.go
Normal file
@@ -0,0 +1,219 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth"
|
||||
requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request"
|
||||
)
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrentRequest_CreateTZ_LimitedBalance() {
|
||||
initialBalance := 50.0
|
||||
email, password, userID := s.createUniqueTestUser("limited_balance_tz", initialBalance)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int32
|
||||
var errorCount int32
|
||||
goroutines := 20
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
_, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: "Параллельный CreateTZ с ограниченным балансом",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
} else {
|
||||
atomic.AddInt32(&errorCount, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("CreateTZ with limited balance - Success: %d, Errors: %d", successCount, errorCount)
|
||||
|
||||
finalBalance := s.getUserBalance(userID)
|
||||
s.T().Logf("Final balance: %.4f (initial: %.4f)", finalBalance, initialBalance)
|
||||
|
||||
s.GreaterOrEqual(finalBalance, 0.0,
|
||||
"Баланс не должен быть отрицательным")
|
||||
|
||||
var requestsWithTZ int
|
||||
err = s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM requests_for_suppliers WHERE user_id = $1 AND generated_tz = true",
|
||||
userID,
|
||||
).Scan(&requestsWithTZ)
|
||||
s.NoError(err)
|
||||
|
||||
s.T().Logf("Requests with generated TZ: %d", requestsWithTZ)
|
||||
|
||||
s.GreaterOrEqual(requestsWithTZ, 0,
|
||||
"Количество запросов с TZ должно быть >= 0")
|
||||
|
||||
s.LessOrEqual(requestsWithTZ, int(successCount),
|
||||
"Количество запросов с TZ не должно превышать успешные операции")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrentRequest_MultipleUsers_CreateTZ() {
|
||||
user1Email, user1Pass, user1ID := s.createUniqueTestUser("multi_user1", 500.0)
|
||||
user2Email, user2Pass, user2ID := s.createUniqueTestUser("multi_user2", 500.0)
|
||||
user3Email, user3Pass, user3ID := s.createUniqueTestUser("multi_user3", 500.0)
|
||||
|
||||
login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: user1Email, Password: user1Pass, Ip: "127.0.0.1", UserAgent: "test",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
validate1, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login1.AccessToken})
|
||||
|
||||
login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: user2Email, Password: user2Pass, Ip: "127.0.0.1", UserAgent: "test",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
validate2, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login2.AccessToken})
|
||||
|
||||
login3, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: user3Email, Password: user3Pass, Ip: "127.0.0.1", UserAgent: "test",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
validate3, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login3.AccessToken})
|
||||
|
||||
users := []struct {
|
||||
userID int64
|
||||
id int
|
||||
}{
|
||||
{validate1.UserId, user1ID},
|
||||
{validate2.UserId, user2ID},
|
||||
{validate3.UserId, user3ID},
|
||||
}
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var totalSuccess int32
|
||||
requestsPerUser := 5
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for _, user := range users {
|
||||
for i := 0; i < requestsPerUser; i++ {
|
||||
wg.Add(1)
|
||||
go func(uid int64) {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
_, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: uid,
|
||||
RequestTxt: "Multi-user concurrent CreateTZ",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&totalSuccess, 1)
|
||||
}
|
||||
}(user.userID)
|
||||
}
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Multi-user CreateTZ total success: %d", totalSuccess)
|
||||
|
||||
for _, user := range users {
|
||||
balance := s.getUserBalance(user.id)
|
||||
s.T().Logf("User %d final balance: %.4f", user.id, balance)
|
||||
s.GreaterOrEqual(balance, 0.0,
|
||||
"Баланс пользователя %d не должен быть отрицательным", user.id)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestConcurrentRequest_BalanceDeduction_Consistency() {
|
||||
initialBalance := 1000.0
|
||||
email, password, userID := s.createUniqueTestUser("balance_consistency", initialBalance)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int32
|
||||
goroutines := 10
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
_, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: "Balance consistency test",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Successful CreateTZ operations: %d", successCount)
|
||||
|
||||
finalBalance := s.getUserBalance(userID)
|
||||
balanceSpent := initialBalance - finalBalance
|
||||
s.T().Logf("Balance spent: %.4f", balanceSpent)
|
||||
|
||||
var totalTokenCost float64
|
||||
err = s.pool.QueryRow(s.ctx, `
|
||||
SELECT COALESCE(SUM(tu.token_cost), 0)
|
||||
FROM request_token_usage tu
|
||||
JOIN requests_for_suppliers r ON tu.request_id = r.id
|
||||
WHERE r.user_id = $1
|
||||
`, userID).Scan(&totalTokenCost)
|
||||
s.NoError(err)
|
||||
|
||||
s.T().Logf("Total token cost from DB: %.4f", totalTokenCost)
|
||||
|
||||
s.GreaterOrEqual(finalBalance, 0.0,
|
||||
"Баланс не должен быть отрицательным")
|
||||
|
||||
if totalTokenCost > 0 {
|
||||
tolerance := 0.01
|
||||
s.InDelta(totalTokenCost, balanceSpent, tolerance,
|
||||
"Сумма token_cost должна соответствовать списанному балансу")
|
||||
}
|
||||
}
|
||||
155
tests/idempotency_test.go
Normal file
155
tests/idempotency_test.go
Normal file
@@ -0,0 +1,155 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth"
|
||||
requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request"
|
||||
)
|
||||
|
||||
func (s *IntegrationSuite) TestIdempotency_DoubleCreateTZ_CreatesTwoRequests() {
|
||||
email, password, userID := s.createUniqueTestUser("idempotency_tz", 1000.0)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
requestText := "Одинаковый текст запроса для теста идемпотентности"
|
||||
|
||||
resp1, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: requestText,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
requestID1 := resp1.RequestId
|
||||
|
||||
resp2, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: requestText,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
requestID2 := resp2.RequestId
|
||||
|
||||
s.T().Logf("Request 1 ID: %s", requestID1)
|
||||
s.T().Logf("Request 2 ID: %s", requestID2)
|
||||
|
||||
s.NotEqual(requestID1, requestID2,
|
||||
"Два вызова CreateTZ должны создать два разных request")
|
||||
|
||||
var requestCount int
|
||||
err = s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM requests_for_suppliers WHERE user_id = $1 AND request_txt = $2",
|
||||
userID, requestText,
|
||||
).Scan(&requestCount)
|
||||
s.NoError(err)
|
||||
|
||||
s.Equal(2, requestCount,
|
||||
"Должно быть создано 2 запроса с одинаковым текстом")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestIdempotency_DoubleRegister_SameInviteCode() {
|
||||
inviteCode := s.createActiveInviteCode(5)
|
||||
|
||||
email1 := fmt.Sprintf("double_reg1_%d@example.com", time.Now().UnixNano())
|
||||
|
||||
resp1, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{
|
||||
Email: email1,
|
||||
Password: "testpassword",
|
||||
Name: "User 1",
|
||||
Phone: fmt.Sprintf("+1%010d", time.Now().UnixNano()%10000000000),
|
||||
InviteCode: inviteCode,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotEmpty(resp1.AccessToken)
|
||||
|
||||
email2 := fmt.Sprintf("double_reg2_%d@example.com", time.Now().UnixNano())
|
||||
|
||||
resp2, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{
|
||||
Email: email2,
|
||||
Password: "testpassword",
|
||||
Name: "User 2",
|
||||
Phone: fmt.Sprintf("+2%010d", time.Now().UnixNano()%10000000000),
|
||||
InviteCode: inviteCode,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
s.NotEmpty(resp2.AccessToken)
|
||||
|
||||
remainingUses := s.getInviteCodeUsageCount(inviteCode)
|
||||
s.T().Logf("Remaining invite uses: %d", remainingUses)
|
||||
|
||||
s.Equal(3, remainingUses,
|
||||
"После двух регистраций должно остаться 3 использования (5-2)")
|
||||
|
||||
validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: resp1.AccessToken,
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
validate2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: resp2.AccessToken,
|
||||
})
|
||||
s.NoError(err)
|
||||
|
||||
s.NotEqual(validate1.UserId, validate2.UserId,
|
||||
"Должны быть созданы два разных пользователя")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestIdempotency_DoubleLogout_SameToken() {
|
||||
email, password, _ := s.createUniqueTestUser("double_logout", 100.0)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
accessToken := loginResp.AccessToken
|
||||
|
||||
validateBefore, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(validateBefore.Valid)
|
||||
|
||||
logout1, err := s.authClient.Logout(s.ctx, &authpb.LogoutRequest{
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(logout1.Success)
|
||||
|
||||
validateAfter1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.False(validateAfter1.Valid,
|
||||
"Токен должен быть невалиден после первого logout")
|
||||
|
||||
logout2, err := s.authClient.Logout(s.ctx, &authpb.LogoutRequest{
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.True(logout2.Success,
|
||||
"Повторный logout должен быть успешным (идемпотентность)")
|
||||
|
||||
validateAfter2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: accessToken,
|
||||
})
|
||||
s.NoError(err)
|
||||
s.False(validateAfter2.Valid,
|
||||
"Токен должен оставаться невалидным после повторного logout")
|
||||
}
|
||||
@@ -275,3 +275,104 @@ func (s *IntegrationSuite) createSecondTestUser() (email string, password string
|
||||
|
||||
return email, password, userID
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) getInviteCodeUsageCount(code int64) int {
|
||||
var count int
|
||||
err := s.pool.QueryRow(s.ctx,
|
||||
"SELECT can_be_used_count FROM invite_codes WHERE code = $1",
|
||||
code,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) getRequestSuppliersCount(requestID string) int {
|
||||
var count int
|
||||
err := s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM suppliers WHERE request_id = $1::uuid",
|
||||
requestID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) getUserBalance(userID int) float64 {
|
||||
var balance float64
|
||||
err := s.pool.QueryRow(s.ctx,
|
||||
"SELECT balance FROM users WHERE id = $1",
|
||||
userID,
|
||||
).Scan(&balance)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
return balance
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) getTokenUsageCount(requestID string) int {
|
||||
var count int
|
||||
err := s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM request_token_usage WHERE request_id = $1::uuid",
|
||||
requestID,
|
||||
).Scan(&count)
|
||||
if err != nil {
|
||||
return -1
|
||||
}
|
||||
return count
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) createUniqueTestUser(suffix string, balance float64) (email string, password string, userID int) {
|
||||
email = fmt.Sprintf("user_%s_%d@example.com", suffix, time.Now().UnixNano())
|
||||
password = "testpassword"
|
||||
|
||||
cryptoHelper := crypto.NewCrypto(testCryptoSecret)
|
||||
|
||||
encryptedEmail, err := cryptoHelper.Encrypt(email)
|
||||
s.Require().NoError(err)
|
||||
|
||||
encryptedPhone, err := cryptoHelper.Encrypt(fmt.Sprintf("+1%d", time.Now().UnixNano()%10000000000))
|
||||
s.Require().NoError(err)
|
||||
|
||||
encryptedUserName, err := cryptoHelper.Encrypt(fmt.Sprintf("User %s", suffix))
|
||||
s.Require().NoError(err)
|
||||
|
||||
emailHash := cryptoHelper.EmailHash(email)
|
||||
passwordHash := crypto.PasswordHash(password)
|
||||
|
||||
query := `
|
||||
INSERT INTO users (email, email_hash, password_hash, phone, user_name, company_name, balance, payment_status, invites_issued, invites_limit)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
RETURNING id
|
||||
`
|
||||
|
||||
err = s.pool.QueryRow(s.ctx, query,
|
||||
encryptedEmail,
|
||||
emailHash,
|
||||
passwordHash,
|
||||
encryptedPhone,
|
||||
encryptedUserName,
|
||||
"Test Company",
|
||||
balance,
|
||||
"active",
|
||||
0,
|
||||
10,
|
||||
).Scan(&userID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
return email, password, userID
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) isInviteCodeActive(code int64) bool {
|
||||
var isActive bool
|
||||
err := s.pool.QueryRow(s.ctx,
|
||||
"SELECT is_active FROM invite_codes WHERE code = $1",
|
||||
code,
|
||||
).Scan(&isActive)
|
||||
if err != nil {
|
||||
return false
|
||||
}
|
||||
return isActive
|
||||
}
|
||||
|
||||
194
tests/transaction_rollback_test.go
Normal file
194
tests/transaction_rollback_test.go
Normal file
@@ -0,0 +1,194 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
|
||||
authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth"
|
||||
requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request"
|
||||
)
|
||||
|
||||
func (s *IntegrationSuite) TestTransaction_CreateTZ_InsufficientBalance_Rollback() {
|
||||
email, password, userID := s.createUniqueTestUser("insufficient_tz", 0.001)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
initialBalance := s.getUserBalance(userID)
|
||||
s.T().Logf("Initial balance: %.4f", initialBalance)
|
||||
|
||||
_, err = s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: "Тест с недостаточным балансом",
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.T().Logf("CreateTZ failed as expected: %v", err)
|
||||
|
||||
finalBalance := s.getUserBalance(userID)
|
||||
s.T().Logf("Final balance after failed CreateTZ: %.4f", finalBalance)
|
||||
|
||||
s.GreaterOrEqual(finalBalance, 0.0,
|
||||
"Баланс не должен быть отрицательным после rollback")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestTransaction_ApproveTZ_InsufficientBalance_NoSuppliers() {
|
||||
email, password, userID := s.createUniqueTestUser("approve_insufficient", 1000.0)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: "Тест approve с недостаточным балансом",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
requestID := createResp.RequestId
|
||||
|
||||
_, err = s.pool.Exec(s.ctx, "UPDATE users SET balance = 0.001 WHERE id = $1", userID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
suppliersBeforeApprove := s.getRequestSuppliersCount(requestID)
|
||||
s.T().Logf("Suppliers before ApproveTZ: %d", suppliersBeforeApprove)
|
||||
|
||||
_, err = s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{
|
||||
RequestId: requestID,
|
||||
FinalTz: "Утвержденное ТЗ",
|
||||
UserId: validateResp.UserId,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
s.T().Logf("ApproveTZ failed as expected: %v", err)
|
||||
|
||||
suppliersAfterApprove := s.getRequestSuppliersCount(requestID)
|
||||
s.T().Logf("Suppliers after failed ApproveTZ: %d", suppliersAfterApprove)
|
||||
|
||||
finalBalance := s.getUserBalance(userID)
|
||||
s.GreaterOrEqual(finalBalance, 0.0,
|
||||
"Баланс не должен быть отрицательным")
|
||||
}
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestTransaction_ConcurrentCreateTZ_BalanceAtomicity() {
|
||||
email, password, userID := s.createUniqueTestUser("concurrent_tz", 100.0)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var successCount int32
|
||||
var errorCount int32
|
||||
goroutines := 10
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func(idx int) {
|
||||
defer wg.Done()
|
||||
|
||||
<-startBarrier
|
||||
|
||||
_, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: "Параллельный тест CreateTZ",
|
||||
})
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&successCount, 1)
|
||||
} else {
|
||||
atomic.AddInt32(&errorCount, 1)
|
||||
}
|
||||
}(i)
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Concurrent CreateTZ - Success: %d, Errors: %d", successCount, errorCount)
|
||||
|
||||
finalBalance := s.getUserBalance(userID)
|
||||
s.T().Logf("Final balance: %.4f", finalBalance)
|
||||
|
||||
s.GreaterOrEqual(finalBalance, 0.0,
|
||||
"Баланс не должен быть отрицательным после параллельных операций")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestTransaction_TokenUsage_BalanceConsistency() {
|
||||
email, password, userID := s.createUniqueTestUser("token_consistency", 1000.0)
|
||||
|
||||
initialBalance := s.getUserBalance(userID)
|
||||
s.T().Logf("Initial balance: %.4f", initialBalance)
|
||||
|
||||
loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{
|
||||
Email: email,
|
||||
Password: password,
|
||||
Ip: "127.0.0.1",
|
||||
UserAgent: "test-agent",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{
|
||||
AccessToken: loginResp.AccessToken,
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
|
||||
createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{
|
||||
UserId: validateResp.UserId,
|
||||
RequestTxt: "Тест consistency token_usage и balance",
|
||||
})
|
||||
s.Require().NoError(err)
|
||||
requestID := createResp.RequestId
|
||||
|
||||
tokenUsageCount := s.getTokenUsageCount(requestID)
|
||||
s.T().Logf("Token usage records for request: %d", tokenUsageCount)
|
||||
|
||||
finalBalance := s.getUserBalance(userID)
|
||||
balanceDelta := initialBalance - finalBalance
|
||||
s.T().Logf("Balance delta: %.4f", balanceDelta)
|
||||
|
||||
if tokenUsageCount > 0 {
|
||||
s.Greater(balanceDelta, 0.0,
|
||||
"Баланс должен уменьшиться при наличии token_usage записей")
|
||||
}
|
||||
|
||||
var totalTokenCost float64
|
||||
err = s.pool.QueryRow(s.ctx,
|
||||
"SELECT COALESCE(SUM(token_cost), 0) FROM request_token_usage WHERE request_id = $1::uuid",
|
||||
requestID,
|
||||
).Scan(&totalTokenCost)
|
||||
s.NoError(err)
|
||||
|
||||
s.T().Logf("Total token cost from DB: %.4f, Balance delta: %.4f", totalTokenCost, balanceDelta)
|
||||
}
|
||||
181
tests/worker_concurrent_test.go
Normal file
181
tests/worker_concurrent_test.go
Normal file
@@ -0,0 +1,181 @@
|
||||
package tests
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/google/uuid"
|
||||
)
|
||||
|
||||
func (s *IntegrationSuite) TestWorkerConcurrent_SessionCleanup_MassExpired() {
|
||||
_, _, userID := s.createUniqueTestUser("session_cleanup", 100.0)
|
||||
|
||||
expiredTime := time.Now().Add(-24 * time.Hour)
|
||||
validTime := time.Now().Add(24 * time.Hour)
|
||||
|
||||
expiredCount := 100
|
||||
validCount := 50
|
||||
|
||||
for i := 0; i < expiredCount; i++ {
|
||||
_, err := s.pool.Exec(s.ctx, `
|
||||
INSERT INTO sessions (user_id, access_token, refresh_token, ip, user_agent, expires_at)
|
||||
VALUES ($1, $2, $3, '127.0.0.1', 'test-agent', $4)
|
||||
`, userID,
|
||||
fmt.Sprintf("expired_access_%d_%s", i, uuid.New().String()),
|
||||
fmt.Sprintf("expired_refresh_%d_%s", i, uuid.New().String()),
|
||||
expiredTime,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
for i := 0; i < validCount; i++ {
|
||||
_, err := s.pool.Exec(s.ctx, `
|
||||
INSERT INTO sessions (user_id, access_token, refresh_token, ip, user_agent, expires_at)
|
||||
VALUES ($1, $2, $3, '127.0.0.1', 'test-agent', $4)
|
||||
`, userID,
|
||||
fmt.Sprintf("valid_access_%d_%s", i, uuid.New().String()),
|
||||
fmt.Sprintf("valid_refresh_%d_%s", i, uuid.New().String()),
|
||||
validTime,
|
||||
)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
var totalBefore int
|
||||
err := s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM sessions WHERE user_id = $1", userID,
|
||||
).Scan(&totalBefore)
|
||||
s.Require().NoError(err)
|
||||
s.T().Logf("Sessions before cleanup: %d", totalBefore)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var totalDeleted int32
|
||||
goroutines := 10
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
result, err := s.pool.Exec(s.ctx, `
|
||||
DELETE FROM sessions
|
||||
WHERE expires_at < now()
|
||||
OR (revoked_at IS NOT NULL AND revoked_at < now() - interval '30 days')
|
||||
`)
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&totalDeleted, int32(result.RowsAffected()))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Total deleted by concurrent cleanup: %d", totalDeleted)
|
||||
|
||||
var validRemaining int
|
||||
err = s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM sessions WHERE user_id = $1 AND expires_at > now()", userID,
|
||||
).Scan(&validRemaining)
|
||||
s.NoError(err)
|
||||
|
||||
s.T().Logf("Valid sessions remaining: %d (expected: %d)", validRemaining, validCount)
|
||||
|
||||
s.Equal(validCount, validRemaining,
|
||||
"Все валидные сессии должны остаться после cleanup")
|
||||
|
||||
s.GreaterOrEqual(int(totalDeleted), expiredCount,
|
||||
"Все истекшие сессии должны быть удалены")
|
||||
}
|
||||
|
||||
func (s *IntegrationSuite) TestWorkerConcurrent_InviteCleanup_MassExpired() {
|
||||
_, _, userID := s.createUniqueTestUser("invite_cleanup", 100.0)
|
||||
|
||||
_, err := s.pool.Exec(s.ctx, "UPDATE users SET invites_limit = 200 WHERE id = $1", userID)
|
||||
s.Require().NoError(err)
|
||||
|
||||
expiredTime := time.Now().Add(-24 * time.Hour)
|
||||
validTime := time.Now().Add(24 * time.Hour)
|
||||
|
||||
expiredCount := 100
|
||||
validCount := 50
|
||||
|
||||
for i := 0; i < expiredCount; i++ {
|
||||
code := int64(30000000 + i)
|
||||
_, err := s.pool.Exec(s.ctx, `
|
||||
INSERT INTO invite_codes (user_id, code, can_be_used_count, expires_at, is_active)
|
||||
VALUES ($1, $2, 5, $3, true)
|
||||
`, userID, code, expiredTime)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
for i := 0; i < validCount; i++ {
|
||||
code := int64(40000000 + i)
|
||||
_, err := s.pool.Exec(s.ctx, `
|
||||
INSERT INTO invite_codes (user_id, code, can_be_used_count, expires_at, is_active)
|
||||
VALUES ($1, $2, 5, $3, true)
|
||||
`, userID, code, validTime)
|
||||
s.Require().NoError(err)
|
||||
}
|
||||
|
||||
var activeBefore int
|
||||
err = s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND is_active = true", userID,
|
||||
).Scan(&activeBefore)
|
||||
s.Require().NoError(err)
|
||||
s.T().Logf("Active invites before cleanup: %d", activeBefore)
|
||||
|
||||
var wg sync.WaitGroup
|
||||
var totalDeactivated int32
|
||||
goroutines := 10
|
||||
|
||||
startBarrier := make(chan struct{})
|
||||
|
||||
for i := 0; i < goroutines; i++ {
|
||||
wg.Add(1)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
<-startBarrier
|
||||
|
||||
result, err := s.pool.Exec(s.ctx, `
|
||||
UPDATE invite_codes
|
||||
SET is_active = false
|
||||
WHERE expires_at < now() AND is_active = true
|
||||
`)
|
||||
|
||||
if err == nil {
|
||||
atomic.AddInt32(&totalDeactivated, int32(result.RowsAffected()))
|
||||
}
|
||||
}()
|
||||
}
|
||||
|
||||
close(startBarrier)
|
||||
wg.Wait()
|
||||
|
||||
s.T().Logf("Total deactivated by concurrent cleanup: %d", totalDeactivated)
|
||||
|
||||
var activeRemaining int
|
||||
err = s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND is_active = true", userID,
|
||||
).Scan(&activeRemaining)
|
||||
s.NoError(err)
|
||||
|
||||
s.T().Logf("Active invites remaining: %d (expected: %d)", activeRemaining, validCount)
|
||||
|
||||
s.Equal(validCount, activeRemaining,
|
||||
"Все валидные инвайты должны остаться активными после cleanup")
|
||||
|
||||
var expiredStillActive int
|
||||
err = s.pool.QueryRow(s.ctx,
|
||||
"SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND expires_at < now() AND is_active = true", userID,
|
||||
).Scan(&expiredStillActive)
|
||||
s.NoError(err)
|
||||
|
||||
s.Equal(0, expiredStillActive,
|
||||
"Не должно остаться активных истекших инвайтов")
|
||||
}
|
||||
Reference in New Issue
Block a user