diff --git a/.gitea/workflows/deploy.yml b/.gitea/workflows/deploy.yml index c05fa83..a133398 100644 --- a/.gitea/workflows/deploy.yml +++ b/.gitea/workflows/deploy.yml @@ -33,7 +33,10 @@ jobs: echo "${{ secrets.DEPLOY_KEY }}" > ~/.ssh/id_rsa chmod 600 ~/.ssh/id_rsa ssh-keyscan -H 185.185.142.203 >> ~/.ssh/known_hosts 2>/dev/null || true - ssh -i ~/.ssh/id_rsa root@185.185.142.203 'cd /home/n8n && /usr/bin/docker compose up -d smart-search-backend' + ssh -i ~/.ssh/id_rsa root@185.185.142.203 ' + docker rm -f smart-search-backend 2>/dev/null || true + cd /home/n8n && /usr/bin/docker compose up -d smart-search-backend + ' - name: Done run: echo "✅ Backend deployed!" diff --git a/.gitignore b/.gitignore index 3272aa1..3b4f725 100644 --- a/.gitignore +++ b/.gitignore @@ -35,4 +35,8 @@ config/config.yaml # Build artifacts /bin/ /dist/ -.gitea \ No newline at end of file +.gitea +.gitea/workflows + +# Coverage +coverage/ \ No newline at end of file diff --git a/Makefile b/Makefile index 8c4ece1..271ea5a 100644 --- a/Makefile +++ b/Makefile @@ -1,4 +1,4 @@ -.PHONY: help build run migrate-up migrate-down migrate-create lint test test-integration proto clean +.PHONY: help build run migrate-up migrate-down migrate-create lint test test-integration test-coverage proto clean help: @echo "Available commands:" @@ -12,6 +12,7 @@ help: @echo " make generate-mock - Generate mocks for all interfaces (minimock)" @echo " make test - Run unit tests" @echo " make test-integration - Run integration tests with testcontainers" + @echo " make test-coverage - Run tests with coverage report" @echo " make clean - Clean build artifacts" build: @@ -87,5 +88,13 @@ test-integration: @echo "This may take several minutes..." go test -v -timeout=10m ./tests/... +test-coverage: + @echo "Running tests with coverage..." + @mkdir -p coverage + go test -timeout=10m -coverprofile=coverage/coverage.out -covermode=atomic ./... + go tool cover -func=coverage/coverage.out + go tool cover -html=coverage/coverage.out -o coverage/coverage.html + @echo "Coverage report: coverage/coverage.html" + # Default DB URL for local development -DB_URL ?= postgres://postgres:password@localhost:5432/b2b_search?sslmode=disable +DB_URL ?= postgres://postgres:password@localhost:5432/b2b_data?sslmode=disable diff --git a/api/proto/invite/invite.proto b/api/proto/invite/invite.proto index 846d0b7..3685364 100644 --- a/api/proto/invite/invite.proto +++ b/api/proto/invite/invite.proto @@ -22,15 +22,12 @@ message GenerateResponse { } message GetInfoRequest { - string code = 1; + int64 user_id = 1; } message GetInfoResponse { string code = 1; - int64 user_id = 2; - int32 can_be_used_count = 3; - int32 used_count = 4; - google.protobuf.Timestamp expires_at = 5; - bool is_active = 6; - google.protobuf.Timestamp created_at = 7; + int32 can_be_used_count = 2; + google.protobuf.Timestamp expires_at = 3; + google.protobuf.Timestamp created_at = 4; } diff --git a/api/proto/request/request.proto b/api/proto/request/request.proto index ed359d1..f2ebfef 100644 --- a/api/proto/request/request.proto +++ b/api/proto/request/request.proto @@ -30,8 +30,17 @@ message ApproveTZRequest { } message ApproveTZResponse { - bool success = 1; - string mailing_status = 2; + string request_id = 1; + repeated Supplier suppliers = 2; +} + +message Supplier { + string id = 1; + string name = 2; + string email = 3; + string phone = 4; + string address = 5; + string url = 6; } message GetMailingListRequest { @@ -48,7 +57,7 @@ message GetMailingListByIDRequest { } message GetMailingListByIDResponse { - MailingItem item = 1; + MailingDetail detail = 1; } message MailingItem { @@ -59,3 +68,10 @@ message MailingItem { google.protobuf.Timestamp created_at = 5; int32 suppliers_found = 6; } + +message MailingDetail { + string request_id = 1; + string title = 2; + string mail_text = 3; + repeated Supplier suppliers = 4; +} diff --git a/api/proto/user/user.proto b/api/proto/user/user.proto index d884374..74fbf19 100644 --- a/api/proto/user/user.proto +++ b/api/proto/user/user.proto @@ -34,18 +34,22 @@ message GetStatisticsRequest { } message GetStatisticsResponse { - int32 total_requests = 1; - int32 successful_requests = 2; - int32 failed_requests = 3; - double total_spent = 4; + string suppliers_count = 1; + string requests_count = 2; + string created_tz = 3; } message GetBalanceStatisticsRequest { int64 user_id = 1; } -message GetBalanceStatisticsResponse { - double balance = 1; - int32 total_requests = 2; - double total_spent = 3; +message WriteOffHistoryItem { + string operation_id = 1; + string data = 2; + double amount = 3; +} + +message GetBalanceStatisticsResponse { + double average_cost = 1; + repeated WriteOffHistoryItem write_off_history = 2; } diff --git a/cmd/server/main.go b/cmd/server/main.go index 90d88a0..8c2cdbe 100644 --- a/cmd/server/main.go +++ b/cmd/server/main.go @@ -2,13 +2,14 @@ package main import ( "context" - "log" + "os" "github.com/jackc/pgx/v5/pgxpool" _ "github.com/jackc/pgx/v5/stdlib" rkboot "github.com/rookie-ninja/rk-boot/v2" rkentry "github.com/rookie-ninja/rk-entry/v2/entry" rkgrpc "github.com/rookie-ninja/rk-grpc/v2/boot" + "go.uber.org/zap" "google.golang.org/grpc" "git.techease.ru/Smart-search/smart-search-back/internal/config" @@ -19,50 +20,51 @@ import ( ) func main() { - cfg, err := config.Load("config/config.yaml") - if err != nil { - log.Fatalf("Failed to load config: %v", err) - } - - ctx := context.Background() - - if err := database.RunMigrations(cfg.DatabaseURL()); err != nil { - log.Fatalf("Failed to run migrations: %v", err) - } - - pool, err := pgxpool.New(ctx, cfg.DatabaseURL()) - if err != nil { - log.Fatalf("Failed to connect to database: %v", err) - } - defer pool.Close() - - if err := pool.Ping(ctx); err != nil { - log.Fatalf("Failed to ping database: %v", err) - } - - log.Println("Successfully connected to database") - boot := rkboot.NewBoot(rkboot.WithBootConfigPath("config/boot.yaml", nil)) - grpcEntry := rkgrpc.GetGrpcEntry("smart-search-service") - if grpcEntry == nil { - log.Fatal("Failed to get gRPC entry from rk-boot") - } - - loggerEntry := rkentry.GlobalAppCtx.GetLoggerEntry("smart-search-service") + loggerEntry := rkentry.GlobalAppCtx.GetLoggerEntry("smart-search-logger") if loggerEntry == nil { loggerEntry = rkentry.GlobalAppCtx.GetLoggerEntryDefault() } logger := loggerEntry.Logger + cfg, err := config.Load("config/config.yaml") + if err != nil { + logger.Fatal("Failed to load config", zap.Error(err)) + } + + ctx := context.Background() + + if err := database.RunMigrations(cfg.DatabaseURL(), logger); err != nil { + logger.Fatal("Failed to run migrations", zap.Error(err)) + } + + pool, err := pgxpool.New(ctx, cfg.DatabaseURL()) + if err != nil { + logger.Fatal("Failed to connect to database", zap.Error(err)) + } + defer pool.Close() + + if err := pool.Ping(ctx); err != nil { + logger.Fatal("Failed to ping database", zap.Error(err)) + } + + logger.Info("Successfully connected to database") + + grpcEntry := rkgrpc.GetGrpcEntry("smart-search-service") + if grpcEntry == nil { + logger.Fatal("Failed to get gRPC entry from rk-boot") + os.Exit(1) + } + sessionRepo := repository.NewSessionRepository(pool) inviteRepo := repository.NewInviteRepository(pool) - sessionCleaner := worker.NewSessionCleaner(ctx, sessionRepo) + sessionCleaner := worker.NewSessionCleaner(ctx, sessionRepo, logger) sessionCleaner.Start() defer sessionCleaner.Stop() - inviteCleaner := worker.NewInviteCleaner(ctx, inviteRepo) + inviteCleaner := worker.NewInviteCleaner(ctx, inviteRepo, logger) inviteCleaner.Start() defer inviteCleaner.Stop() @@ -81,9 +83,9 @@ func main() { boot.Bootstrap(ctx) - log.Println("gRPC server started via rk-boot on port 9091") + logger.Info("gRPC server started via rk-boot") boot.WaitForShutdownSig(ctx) - log.Println("Server stopped gracefully") + logger.Info("Server stopped gracefully") } diff --git a/config/boot.yaml b/config/boot.yaml index da5e9e0..0576378 100644 --- a/config/boot.yaml +++ b/config/boot.yaml @@ -1,34 +1,27 @@ --- logger: - name: smart-search-logger - description: "Application logger for smart-search service" default: true zap: - level: error - development: false + level: info encoding: console outputPaths: ["stdout"] errorOutputPaths: ["stderr"] - disableCaller: false - disableStacktrace: false grpc: - name: smart-search-service - port: 9091 + port: 9092 enabled: true enableReflection: true enableRkGwOption: true loggerEntry: smart-search-logger - eventEntry: smart-search-logger middleware: logging: - enabled: true - loggerEncoding: "console" - loggerOutputPaths: ["stdout"] + enabled: false meta: - enabled: true + enabled: false trace: - enabled: true + enabled: false prometheus: enabled: true auth: diff --git a/config/config.yaml.example b/config/config.yaml.example index d130da5..37c149f 100644 --- a/config/config.yaml.example +++ b/config/config.yaml.example @@ -16,9 +16,5 @@ security: jwt_secret: ${JWT_SECRET:xM8KhJVkk28cIJeBo0306O2e6Ifni6tNVlcCMxDFAEc=} crypto_secret: ${CRYPTO_SECRET:xM8KhJVkk28cIJeBo0306O2e6Ifni6tNVlcCMxDFAEc=} -grpc: - port: ${GRPC_PORT:9091} - max_connections: ${GRPC_MAX_CONNS:100} - logging: level: ${LOG_LEVEL:info} diff --git a/go.mod b/go.mod index 3178d64..a5df3ac 100644 --- a/go.mod +++ b/go.mod @@ -17,6 +17,7 @@ require ( github.com/testcontainers/testcontainers-go/modules/postgres v0.40.0 github.com/xuri/excelize/v2 v2.10.0 go.uber.org/zap v1.27.1 + google.golang.org/genproto/googleapis/rpc v0.0.0-20260114163908-3f89685c29c3 google.golang.org/grpc v1.78.0 google.golang.org/protobuf v1.36.11 gopkg.in/yaml.v3 v3.0.1 @@ -95,7 +96,6 @@ require ( github.com/spf13/cast v1.10.0 // indirect github.com/spf13/pflag v1.0.10 // indirect github.com/spf13/viper v1.21.0 // indirect - github.com/stretchr/objx v0.5.2 // indirect github.com/subosito/gotenv v1.6.0 // indirect github.com/tiendc/go-deepcopy v1.7.1 // indirect github.com/tklauser/go-sysconf v0.3.12 // indirect @@ -126,7 +126,6 @@ require ( golang.org/x/sys v0.40.0 // indirect golang.org/x/text v0.33.0 // indirect google.golang.org/genproto/googleapis/api v0.0.0-20251029180050-ab9386a59fda // indirect - google.golang.org/genproto/googleapis/rpc v0.0.0-20260114163908-3f89685c29c3 // indirect gopkg.in/natefinch/lumberjack.v2 v2.2.1 // indirect gopkg.in/yaml.v2 v2.4.0 // indirect nhooyr.io/websocket v1.8.6 // indirect diff --git a/internal/ai/openai.go b/internal/ai/openai.go index 7278e29..622f2c5 100644 --- a/internal/ai/openai.go +++ b/internal/ai/openai.go @@ -41,9 +41,17 @@ type tzResponse struct { } func NewOpenAIClient(apiKey string) *OpenAIClient { + transport := &http.Transport{ + Proxy: http.ProxyFromEnvironment, + } + + client := &http.Client{ + Transport: transport, + } + return &OpenAIClient{ apiKey: apiKey, - client: &http.Client{}, + client: client, } } diff --git a/internal/ai/perplexity.go b/internal/ai/perplexity.go index 4b778e4..7a0c3e6 100644 --- a/internal/ai/perplexity.go +++ b/internal/ai/perplexity.go @@ -45,7 +45,7 @@ type supplierData struct { CompanyName string `json:"company_name"` Email string `json:"email"` Phone string `json:"phone"` - Address string `json:"adress"` + Address string `json:"address"` URL string `json:"url"` } @@ -123,7 +123,7 @@ func (c *PerplexityClient) FindSuppliers(tzText string) ([]*model.Supplier, int, - Сортируй по релевантности`, tzText) reqBody := perplexityRequest{ - Model: "llama-3.1-sonar-large-128k-online", + Model: "sonar", Messages: []perplexityMessage{ { Role: "user", diff --git a/internal/config/config.go b/internal/config/config.go index 9f9b061..ab97a3c 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -12,7 +12,6 @@ type Config struct { Database DatabaseConfig `yaml:"database"` AI AIConfig `yaml:"ai"` Security SecurityConfig `yaml:"security"` - GRPC GRPCConfig `yaml:"grpc"` Logging LoggingConfig `yaml:"logging"` } @@ -37,11 +36,6 @@ type SecurityConfig struct { CryptoSecret string `yaml:"crypto_secret"` } -type GRPCConfig struct { - Port int `yaml:"port"` - MaxConnections int `yaml:"max_connections"` -} - type LoggingConfig struct { Level string `yaml:"level"` } diff --git a/internal/database/migrations.go b/internal/database/migrations.go index 97a8d6c..86129d9 100644 --- a/internal/database/migrations.go +++ b/internal/database/migrations.go @@ -7,13 +7,14 @@ import ( _ "github.com/jackc/pgx/v5/stdlib" "github.com/pressly/goose/v3" + "go.uber.org/zap" ) -func RunMigrations(databaseURL string) error { - return RunMigrationsFromPath(databaseURL, "migrations") +func RunMigrations(databaseURL string, logger *zap.Logger) error { + return RunMigrationsFromPath(databaseURL, "migrations", logger) } -func RunMigrationsFromPath(databaseURL, migrationsDir string) error { +func RunMigrationsFromPath(databaseURL, migrationsDir string, logger *zap.Logger) error { db, err := sql.Open("pgx", databaseURL) if err != nil { return fmt.Errorf("failed to open database connection for migrations: %w", err) @@ -28,6 +29,8 @@ func RunMigrationsFromPath(databaseURL, migrationsDir string) error { return fmt.Errorf("failed to set goose dialect: %w", err) } + goose.SetLogger(&gooseLogger{logger: logger}) + absPath, err := filepath.Abs(migrationsDir) if err != nil { return fmt.Errorf("failed to resolve migrations path: %w", err) @@ -39,3 +42,15 @@ func RunMigrationsFromPath(databaseURL, migrationsDir string) error { return nil } + +type gooseLogger struct { + logger *zap.Logger +} + +func (l *gooseLogger) Fatalf(format string, v ...interface{}) { + l.logger.Fatal(fmt.Sprintf(format, v...)) +} + +func (l *gooseLogger) Printf(format string, v ...interface{}) { + l.logger.Info(fmt.Sprintf(format, v...)) +} diff --git a/internal/grpc/invite_handler.go b/internal/grpc/invite_handler.go index df490d4..6d83ddf 100644 --- a/internal/grpc/invite_handler.go +++ b/internal/grpc/invite_handler.go @@ -23,23 +23,15 @@ func (h *InviteHandler) Generate(ctx context.Context, req *pb.GenerateRequest) ( } func (h *InviteHandler) GetInfo(ctx context.Context, req *pb.GetInfoRequest) (*pb.GetInfoResponse, error) { - code, err := strconv.ParseInt(req.Code, 10, 64) - if err != nil { - return nil, errors.ToGRPCError(err, h.logger, "InviteService.GetInfo") - } - - invite, err := h.inviteService.GetInfo(ctx, code) + invite, err := h.inviteService.GetInfo(ctx, int(req.UserId)) if err != nil { return nil, errors.ToGRPCError(err, h.logger, "InviteService.GetInfo") } return &pb.GetInfoResponse{ Code: strconv.FormatInt(invite.Code, 10), - UserId: int64(invite.UserID), CanBeUsedCount: int32(invite.CanBeUsedCount), - UsedCount: int32(invite.UsedCount), ExpiresAt: timestamppb.New(invite.ExpiresAt), - IsActive: invite.IsActive, CreatedAt: timestamppb.New(invite.CreatedAt), }, nil } diff --git a/internal/grpc/request_handler.go b/internal/grpc/request_handler.go index 352989c..a3ea20a 100644 --- a/internal/grpc/request_handler.go +++ b/internal/grpc/request_handler.go @@ -2,7 +2,7 @@ package grpc import ( "context" - "time" + "strconv" "git.techease.ru/Smart-search/smart-search-back/pkg/errors" pb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" @@ -11,12 +11,7 @@ import ( ) func (h *RequestHandler) CreateTZ(ctx context.Context, req *pb.CreateTZRequest) (*pb.CreateTZResponse, error) { - requestTxt := req.RequestTxt - if len(req.FileData) > 0 { - requestTxt += "\n[File: " + req.FileName + "]" - } - - requestID, tzText, err := h.requestService.CreateTZ(ctx, int(req.UserId), requestTxt) + requestID, tzText, err := h.requestService.CreateTZ(ctx, int(req.UserId), req.RequestTxt, req.FileData, req.FileName) if err != nil { return nil, errors.ToGRPCError(err, h.logger, "RequestService.CreateTZ") } @@ -33,14 +28,26 @@ func (h *RequestHandler) ApproveTZ(ctx context.Context, req *pb.ApproveTZRequest return nil, errors.ToGRPCError(err, h.logger, "RequestService.ApproveTZ") } - _, err = h.requestService.ApproveTZ(ctx, requestID, req.FinalTz, int(req.UserId)) + suppliers, err := h.requestService.ApproveTZ(ctx, requestID, req.FinalTz, int(req.UserId)) if err != nil { return nil, errors.ToGRPCError(err, h.logger, "RequestService.ApproveTZ") } + pbSuppliers := make([]*pb.Supplier, 0, len(suppliers)) + for _, s := range suppliers { + pbSuppliers = append(pbSuppliers, &pb.Supplier{ + Id: strconv.Itoa(s.ID), + Name: s.Name, + Email: s.Email, + Phone: s.Phone, + Address: s.Address, + Url: s.URL, + }) + } + return &pb.ApproveTZResponse{ - Success: true, - MailingStatus: "sent", + RequestId: req.RequestId, + Suppliers: pbSuppliers, }, nil } @@ -73,19 +80,28 @@ func (h *RequestHandler) GetMailingListByID(ctx context.Context, req *pb.GetMail return nil, errors.ToGRPCError(err, h.logger, "RequestService.GetMailingListByID") } - detail, err := h.requestService.GetMailingListByID(ctx, requestID) + detail, err := h.requestService.GetMailingListByID(ctx, requestID, int(req.UserId)) if err != nil { return nil, errors.ToGRPCError(err, h.logger, "RequestService.GetMailingListByID") } + suppliers := make([]*pb.Supplier, 0, len(detail.Suppliers)) + for _, s := range detail.Suppliers { + suppliers = append(suppliers, &pb.Supplier{ + Id: strconv.Itoa(s.CompanyID), + Name: s.CompanyName, + Email: s.Email, + Phone: s.Phone, + Url: s.URL, + }) + } + return &pb.GetMailingListByIDResponse{ - Item: &pb.MailingItem{ - RequestId: detail.RequestID.String(), - RequestTxt: detail.Title, - FinalTz: detail.MailText, - MailingStatus: "sent", - CreatedAt: timestamppb.New(time.Now()), - SuppliersFound: int32(len(detail.Suppliers)), + Detail: &pb.MailingDetail{ + RequestId: detail.RequestID.String(), + Title: detail.Title, + MailText: detail.MailText, + Suppliers: suppliers, }, }, nil } diff --git a/internal/grpc/server.go b/internal/grpc/server.go index 5333000..200426e 100644 --- a/internal/grpc/server.go +++ b/internal/grpc/server.go @@ -59,10 +59,10 @@ func NewHandlers(pool *pgxpool.Pool, jwtSecret, cryptoSecret, openAIKey, perplex perplexityClient := ai.NewPerplexityClient(perplexityKey) authService := service.NewAuthService(userRepo, sessionRepo, inviteRepo, txManager, jwtSecret, cryptoSecret) - userService := service.NewUserService(userRepo, requestRepo, cryptoSecret) + userService := service.NewUserService(userRepo, requestRepo, tokenUsageRepo, cryptoSecret) inviteService := service.NewInviteService(inviteRepo, userRepo, txManager) requestService := service.NewRequestService(requestRepo, supplierRepo, tokenUsageRepo, userRepo, openAIClient, perplexityClient, txManager) - supplierService := service.NewSupplierService(supplierRepo) + supplierService := service.NewSupplierService(supplierRepo, requestRepo) return &AuthHandler{authService: authService, logger: logger}, &UserHandler{userService: userService, logger: logger}, diff --git a/internal/grpc/supplier_handler.go b/internal/grpc/supplier_handler.go index 83b8652..8f6be2b 100644 --- a/internal/grpc/supplier_handler.go +++ b/internal/grpc/supplier_handler.go @@ -14,7 +14,7 @@ func (h *SupplierHandler) ExportExcel(ctx context.Context, req *pb.ExportExcelRe return nil, errors.ToGRPCError(err, h.logger, "SupplierService.ExportExcel") } - fileData, err := h.supplierService.ExportExcel(ctx, requestID) + fileData, err := h.supplierService.ExportExcel(ctx, requestID, int(req.UserId)) if err != nil { return nil, errors.ToGRPCError(err, h.logger, "SupplierService.ExportExcel") } diff --git a/internal/grpc/user_handler.go b/internal/grpc/user_handler.go index 81862ec..e4fed82 100644 --- a/internal/grpc/user_handler.go +++ b/internal/grpc/user_handler.go @@ -2,6 +2,7 @@ package grpc import ( "context" + "strconv" "git.techease.ru/Smart-search/smart-search-back/pkg/errors" pb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/user" @@ -40,27 +41,29 @@ func (h *UserHandler) GetStatistics(ctx context.Context, req *pb.GetStatisticsRe } return &pb.GetStatisticsResponse{ - TotalRequests: int32(stats.RequestsCount), - SuccessfulRequests: int32(stats.SuppliersCount), - FailedRequests: 0, - TotalSpent: 0, + SuppliersCount: strconv.Itoa(stats.SuppliersCount), + RequestsCount: strconv.Itoa(stats.RequestsCount), + CreatedTz: strconv.Itoa(stats.CreatedTZ), }, nil } func (h *UserHandler) GetBalanceStatistics(ctx context.Context, req *pb.GetBalanceStatisticsRequest) (*pb.GetBalanceStatisticsResponse, error) { - balance, err := h.userService.GetBalance(ctx, int(req.UserId)) + stats, err := h.userService.GetBalanceStatistics(ctx, int(req.UserId)) if err != nil { return nil, errors.ToGRPCError(err, h.logger, "UserService.GetBalanceStatistics") } - stats, err := h.userService.GetStatistics(ctx, int(req.UserId)) - if err != nil { - return nil, errors.ToGRPCError(err, h.logger, "UserService.GetBalanceStatistics") + history := make([]*pb.WriteOffHistoryItem, 0, len(stats.WriteOffHistory)) + for _, item := range stats.WriteOffHistory { + history = append(history, &pb.WriteOffHistoryItem{ + OperationId: item.OperationID, + Data: item.Data, + Amount: item.Amount, + }) } return &pb.GetBalanceStatisticsResponse{ - Balance: balance, - TotalRequests: int32(stats.RequestsCount), - TotalSpent: 0, + AverageCost: stats.AverageCost, + WriteOffHistory: history, }, nil } diff --git a/internal/mocks/invite_repository_mock.go b/internal/mocks/invite_repository_mock.go index 268d4e5..993e8d4 100644 --- a/internal/mocks/invite_repository_mock.go +++ b/internal/mocks/invite_repository_mock.go @@ -55,6 +55,13 @@ type InviteRepositoryMock struct { beforeFindActiveByCodeCounter uint64 FindActiveByCodeMock mInviteRepositoryMockFindActiveByCode + funcFindActiveByUserID func(ctx context.Context, userID int) (ip1 *model.InviteCode, err error) + funcFindActiveByUserIDOrigin string + inspectFuncFindActiveByUserID func(ctx context.Context, userID int) + afterFindActiveByUserIDCounter uint64 + beforeFindActiveByUserIDCounter uint64 + FindActiveByUserIDMock mInviteRepositoryMockFindActiveByUserID + funcFindByCode func(ctx context.Context, code int64) (ip1 *model.InviteCode, err error) funcFindByCodeOrigin string inspectFuncFindByCode func(ctx context.Context, code int64) @@ -68,13 +75,6 @@ type InviteRepositoryMock struct { afterGetUserInvitesCounter uint64 beforeGetUserInvitesCounter uint64 GetUserInvitesMock mInviteRepositoryMockGetUserInvites - - funcIncrementUsedCount func(ctx context.Context, code int64) (err error) - funcIncrementUsedCountOrigin string - inspectFuncIncrementUsedCount func(ctx context.Context, code int64) - afterIncrementUsedCountCounter uint64 - beforeIncrementUsedCountCounter uint64 - IncrementUsedCountMock mInviteRepositoryMockIncrementUsedCount } // NewInviteRepositoryMock returns a mock for mm_repository.InviteRepository @@ -100,15 +100,15 @@ func NewInviteRepositoryMock(t minimock.Tester) *InviteRepositoryMock { m.FindActiveByCodeMock = mInviteRepositoryMockFindActiveByCode{mock: m} m.FindActiveByCodeMock.callArgs = []*InviteRepositoryMockFindActiveByCodeParams{} + m.FindActiveByUserIDMock = mInviteRepositoryMockFindActiveByUserID{mock: m} + m.FindActiveByUserIDMock.callArgs = []*InviteRepositoryMockFindActiveByUserIDParams{} + m.FindByCodeMock = mInviteRepositoryMockFindByCode{mock: m} m.FindByCodeMock.callArgs = []*InviteRepositoryMockFindByCodeParams{} m.GetUserInvitesMock = mInviteRepositoryMockGetUserInvites{mock: m} m.GetUserInvitesMock.callArgs = []*InviteRepositoryMockGetUserInvitesParams{} - m.IncrementUsedCountMock = mInviteRepositoryMockIncrementUsedCount{mock: m} - m.IncrementUsedCountMock.callArgs = []*InviteRepositoryMockIncrementUsedCountParams{} - t.Cleanup(m.MinimockFinish) return m @@ -1857,6 +1857,349 @@ func (m *InviteRepositoryMock) MinimockFindActiveByCodeInspect() { } } +type mInviteRepositoryMockFindActiveByUserID struct { + optional bool + mock *InviteRepositoryMock + defaultExpectation *InviteRepositoryMockFindActiveByUserIDExpectation + expectations []*InviteRepositoryMockFindActiveByUserIDExpectation + + callArgs []*InviteRepositoryMockFindActiveByUserIDParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// InviteRepositoryMockFindActiveByUserIDExpectation specifies expectation struct of the InviteRepository.FindActiveByUserID +type InviteRepositoryMockFindActiveByUserIDExpectation struct { + mock *InviteRepositoryMock + params *InviteRepositoryMockFindActiveByUserIDParams + paramPtrs *InviteRepositoryMockFindActiveByUserIDParamPtrs + expectationOrigins InviteRepositoryMockFindActiveByUserIDExpectationOrigins + results *InviteRepositoryMockFindActiveByUserIDResults + returnOrigin string + Counter uint64 +} + +// InviteRepositoryMockFindActiveByUserIDParams contains parameters of the InviteRepository.FindActiveByUserID +type InviteRepositoryMockFindActiveByUserIDParams struct { + ctx context.Context + userID int +} + +// InviteRepositoryMockFindActiveByUserIDParamPtrs contains pointers to parameters of the InviteRepository.FindActiveByUserID +type InviteRepositoryMockFindActiveByUserIDParamPtrs struct { + ctx *context.Context + userID *int +} + +// InviteRepositoryMockFindActiveByUserIDResults contains results of the InviteRepository.FindActiveByUserID +type InviteRepositoryMockFindActiveByUserIDResults struct { + ip1 *model.InviteCode + err error +} + +// InviteRepositoryMockFindActiveByUserIDOrigins contains origins of expectations of the InviteRepository.FindActiveByUserID +type InviteRepositoryMockFindActiveByUserIDExpectationOrigins struct { + origin string + originCtx string + originUserID string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) Optional() *mInviteRepositoryMockFindActiveByUserID { + mmFindActiveByUserID.optional = true + return mmFindActiveByUserID +} + +// Expect sets up expected params for InviteRepository.FindActiveByUserID +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) Expect(ctx context.Context, userID int) *mInviteRepositoryMockFindActiveByUserID { + if mmFindActiveByUserID.mock.funcFindActiveByUserID != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by Set") + } + + if mmFindActiveByUserID.defaultExpectation == nil { + mmFindActiveByUserID.defaultExpectation = &InviteRepositoryMockFindActiveByUserIDExpectation{} + } + + if mmFindActiveByUserID.defaultExpectation.paramPtrs != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by ExpectParams functions") + } + + mmFindActiveByUserID.defaultExpectation.params = &InviteRepositoryMockFindActiveByUserIDParams{ctx, userID} + mmFindActiveByUserID.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmFindActiveByUserID.expectations { + if minimock.Equal(e.params, mmFindActiveByUserID.defaultExpectation.params) { + mmFindActiveByUserID.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmFindActiveByUserID.defaultExpectation.params) + } + } + + return mmFindActiveByUserID +} + +// ExpectCtxParam1 sets up expected param ctx for InviteRepository.FindActiveByUserID +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) ExpectCtxParam1(ctx context.Context) *mInviteRepositoryMockFindActiveByUserID { + if mmFindActiveByUserID.mock.funcFindActiveByUserID != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by Set") + } + + if mmFindActiveByUserID.defaultExpectation == nil { + mmFindActiveByUserID.defaultExpectation = &InviteRepositoryMockFindActiveByUserIDExpectation{} + } + + if mmFindActiveByUserID.defaultExpectation.params != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by Expect") + } + + if mmFindActiveByUserID.defaultExpectation.paramPtrs == nil { + mmFindActiveByUserID.defaultExpectation.paramPtrs = &InviteRepositoryMockFindActiveByUserIDParamPtrs{} + } + mmFindActiveByUserID.defaultExpectation.paramPtrs.ctx = &ctx + mmFindActiveByUserID.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmFindActiveByUserID +} + +// ExpectUserIDParam2 sets up expected param userID for InviteRepository.FindActiveByUserID +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) ExpectUserIDParam2(userID int) *mInviteRepositoryMockFindActiveByUserID { + if mmFindActiveByUserID.mock.funcFindActiveByUserID != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by Set") + } + + if mmFindActiveByUserID.defaultExpectation == nil { + mmFindActiveByUserID.defaultExpectation = &InviteRepositoryMockFindActiveByUserIDExpectation{} + } + + if mmFindActiveByUserID.defaultExpectation.params != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by Expect") + } + + if mmFindActiveByUserID.defaultExpectation.paramPtrs == nil { + mmFindActiveByUserID.defaultExpectation.paramPtrs = &InviteRepositoryMockFindActiveByUserIDParamPtrs{} + } + mmFindActiveByUserID.defaultExpectation.paramPtrs.userID = &userID + mmFindActiveByUserID.defaultExpectation.expectationOrigins.originUserID = minimock.CallerInfo(1) + + return mmFindActiveByUserID +} + +// Inspect accepts an inspector function that has same arguments as the InviteRepository.FindActiveByUserID +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) Inspect(f func(ctx context.Context, userID int)) *mInviteRepositoryMockFindActiveByUserID { + if mmFindActiveByUserID.mock.inspectFuncFindActiveByUserID != nil { + mmFindActiveByUserID.mock.t.Fatalf("Inspect function is already set for InviteRepositoryMock.FindActiveByUserID") + } + + mmFindActiveByUserID.mock.inspectFuncFindActiveByUserID = f + + return mmFindActiveByUserID +} + +// Return sets up results that will be returned by InviteRepository.FindActiveByUserID +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) Return(ip1 *model.InviteCode, err error) *InviteRepositoryMock { + if mmFindActiveByUserID.mock.funcFindActiveByUserID != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by Set") + } + + if mmFindActiveByUserID.defaultExpectation == nil { + mmFindActiveByUserID.defaultExpectation = &InviteRepositoryMockFindActiveByUserIDExpectation{mock: mmFindActiveByUserID.mock} + } + mmFindActiveByUserID.defaultExpectation.results = &InviteRepositoryMockFindActiveByUserIDResults{ip1, err} + mmFindActiveByUserID.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmFindActiveByUserID.mock +} + +// Set uses given function f to mock the InviteRepository.FindActiveByUserID method +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) Set(f func(ctx context.Context, userID int) (ip1 *model.InviteCode, err error)) *InviteRepositoryMock { + if mmFindActiveByUserID.defaultExpectation != nil { + mmFindActiveByUserID.mock.t.Fatalf("Default expectation is already set for the InviteRepository.FindActiveByUserID method") + } + + if len(mmFindActiveByUserID.expectations) > 0 { + mmFindActiveByUserID.mock.t.Fatalf("Some expectations are already set for the InviteRepository.FindActiveByUserID method") + } + + mmFindActiveByUserID.mock.funcFindActiveByUserID = f + mmFindActiveByUserID.mock.funcFindActiveByUserIDOrigin = minimock.CallerInfo(1) + return mmFindActiveByUserID.mock +} + +// When sets expectation for the InviteRepository.FindActiveByUserID which will trigger the result defined by the following +// Then helper +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) When(ctx context.Context, userID int) *InviteRepositoryMockFindActiveByUserIDExpectation { + if mmFindActiveByUserID.mock.funcFindActiveByUserID != nil { + mmFindActiveByUserID.mock.t.Fatalf("InviteRepositoryMock.FindActiveByUserID mock is already set by Set") + } + + expectation := &InviteRepositoryMockFindActiveByUserIDExpectation{ + mock: mmFindActiveByUserID.mock, + params: &InviteRepositoryMockFindActiveByUserIDParams{ctx, userID}, + expectationOrigins: InviteRepositoryMockFindActiveByUserIDExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmFindActiveByUserID.expectations = append(mmFindActiveByUserID.expectations, expectation) + return expectation +} + +// Then sets up InviteRepository.FindActiveByUserID return parameters for the expectation previously defined by the When method +func (e *InviteRepositoryMockFindActiveByUserIDExpectation) Then(ip1 *model.InviteCode, err error) *InviteRepositoryMock { + e.results = &InviteRepositoryMockFindActiveByUserIDResults{ip1, err} + return e.mock +} + +// Times sets number of times InviteRepository.FindActiveByUserID should be invoked +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) Times(n uint64) *mInviteRepositoryMockFindActiveByUserID { + if n == 0 { + mmFindActiveByUserID.mock.t.Fatalf("Times of InviteRepositoryMock.FindActiveByUserID mock can not be zero") + } + mm_atomic.StoreUint64(&mmFindActiveByUserID.expectedInvocations, n) + mmFindActiveByUserID.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmFindActiveByUserID +} + +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) invocationsDone() bool { + if len(mmFindActiveByUserID.expectations) == 0 && mmFindActiveByUserID.defaultExpectation == nil && mmFindActiveByUserID.mock.funcFindActiveByUserID == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmFindActiveByUserID.mock.afterFindActiveByUserIDCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmFindActiveByUserID.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// FindActiveByUserID implements mm_repository.InviteRepository +func (mmFindActiveByUserID *InviteRepositoryMock) FindActiveByUserID(ctx context.Context, userID int) (ip1 *model.InviteCode, err error) { + mm_atomic.AddUint64(&mmFindActiveByUserID.beforeFindActiveByUserIDCounter, 1) + defer mm_atomic.AddUint64(&mmFindActiveByUserID.afterFindActiveByUserIDCounter, 1) + + mmFindActiveByUserID.t.Helper() + + if mmFindActiveByUserID.inspectFuncFindActiveByUserID != nil { + mmFindActiveByUserID.inspectFuncFindActiveByUserID(ctx, userID) + } + + mm_params := InviteRepositoryMockFindActiveByUserIDParams{ctx, userID} + + // Record call args + mmFindActiveByUserID.FindActiveByUserIDMock.mutex.Lock() + mmFindActiveByUserID.FindActiveByUserIDMock.callArgs = append(mmFindActiveByUserID.FindActiveByUserIDMock.callArgs, &mm_params) + mmFindActiveByUserID.FindActiveByUserIDMock.mutex.Unlock() + + for _, e := range mmFindActiveByUserID.FindActiveByUserIDMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.ip1, e.results.err + } + } + + if mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation.Counter, 1) + mm_want := mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation.params + mm_want_ptrs := mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation.paramPtrs + + mm_got := InviteRepositoryMockFindActiveByUserIDParams{ctx, userID} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmFindActiveByUserID.t.Errorf("InviteRepositoryMock.FindActiveByUserID got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.userID != nil && !minimock.Equal(*mm_want_ptrs.userID, mm_got.userID) { + mmFindActiveByUserID.t.Errorf("InviteRepositoryMock.FindActiveByUserID got unexpected parameter userID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation.expectationOrigins.originUserID, *mm_want_ptrs.userID, mm_got.userID, minimock.Diff(*mm_want_ptrs.userID, mm_got.userID)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmFindActiveByUserID.t.Errorf("InviteRepositoryMock.FindActiveByUserID got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmFindActiveByUserID.FindActiveByUserIDMock.defaultExpectation.results + if mm_results == nil { + mmFindActiveByUserID.t.Fatal("No results are set for the InviteRepositoryMock.FindActiveByUserID") + } + return (*mm_results).ip1, (*mm_results).err + } + if mmFindActiveByUserID.funcFindActiveByUserID != nil { + return mmFindActiveByUserID.funcFindActiveByUserID(ctx, userID) + } + mmFindActiveByUserID.t.Fatalf("Unexpected call to InviteRepositoryMock.FindActiveByUserID. %v %v", ctx, userID) + return +} + +// FindActiveByUserIDAfterCounter returns a count of finished InviteRepositoryMock.FindActiveByUserID invocations +func (mmFindActiveByUserID *InviteRepositoryMock) FindActiveByUserIDAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmFindActiveByUserID.afterFindActiveByUserIDCounter) +} + +// FindActiveByUserIDBeforeCounter returns a count of InviteRepositoryMock.FindActiveByUserID invocations +func (mmFindActiveByUserID *InviteRepositoryMock) FindActiveByUserIDBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmFindActiveByUserID.beforeFindActiveByUserIDCounter) +} + +// Calls returns a list of arguments used in each call to InviteRepositoryMock.FindActiveByUserID. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmFindActiveByUserID *mInviteRepositoryMockFindActiveByUserID) Calls() []*InviteRepositoryMockFindActiveByUserIDParams { + mmFindActiveByUserID.mutex.RLock() + + argCopy := make([]*InviteRepositoryMockFindActiveByUserIDParams, len(mmFindActiveByUserID.callArgs)) + copy(argCopy, mmFindActiveByUserID.callArgs) + + mmFindActiveByUserID.mutex.RUnlock() + + return argCopy +} + +// MinimockFindActiveByUserIDDone returns true if the count of the FindActiveByUserID invocations corresponds +// the number of defined expectations +func (m *InviteRepositoryMock) MinimockFindActiveByUserIDDone() bool { + if m.FindActiveByUserIDMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.FindActiveByUserIDMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.FindActiveByUserIDMock.invocationsDone() +} + +// MinimockFindActiveByUserIDInspect logs each unmet expectation +func (m *InviteRepositoryMock) MinimockFindActiveByUserIDInspect() { + for _, e := range m.FindActiveByUserIDMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByUserID at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterFindActiveByUserIDCounter := mm_atomic.LoadUint64(&m.afterFindActiveByUserIDCounter) + // if default expectation was set then invocations count should be greater than zero + if m.FindActiveByUserIDMock.defaultExpectation != nil && afterFindActiveByUserIDCounter < 1 { + if m.FindActiveByUserIDMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByUserID at\n%s", m.FindActiveByUserIDMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByUserID at\n%s with params: %#v", m.FindActiveByUserIDMock.defaultExpectation.expectationOrigins.origin, *m.FindActiveByUserIDMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcFindActiveByUserID != nil && afterFindActiveByUserIDCounter < 1 { + m.t.Errorf("Expected call to InviteRepositoryMock.FindActiveByUserID at\n%s", m.funcFindActiveByUserIDOrigin) + } + + if !m.FindActiveByUserIDMock.invocationsDone() && afterFindActiveByUserIDCounter > 0 { + m.t.Errorf("Expected %d calls to InviteRepositoryMock.FindActiveByUserID at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.FindActiveByUserIDMock.expectedInvocations), m.FindActiveByUserIDMock.expectedInvocationsOrigin, afterFindActiveByUserIDCounter) + } +} + type mInviteRepositoryMockFindByCode struct { optional bool mock *InviteRepositoryMock @@ -2543,348 +2886,6 @@ func (m *InviteRepositoryMock) MinimockGetUserInvitesInspect() { } } -type mInviteRepositoryMockIncrementUsedCount struct { - optional bool - mock *InviteRepositoryMock - defaultExpectation *InviteRepositoryMockIncrementUsedCountExpectation - expectations []*InviteRepositoryMockIncrementUsedCountExpectation - - callArgs []*InviteRepositoryMockIncrementUsedCountParams - mutex sync.RWMutex - - expectedInvocations uint64 - expectedInvocationsOrigin string -} - -// InviteRepositoryMockIncrementUsedCountExpectation specifies expectation struct of the InviteRepository.IncrementUsedCount -type InviteRepositoryMockIncrementUsedCountExpectation struct { - mock *InviteRepositoryMock - params *InviteRepositoryMockIncrementUsedCountParams - paramPtrs *InviteRepositoryMockIncrementUsedCountParamPtrs - expectationOrigins InviteRepositoryMockIncrementUsedCountExpectationOrigins - results *InviteRepositoryMockIncrementUsedCountResults - returnOrigin string - Counter uint64 -} - -// InviteRepositoryMockIncrementUsedCountParams contains parameters of the InviteRepository.IncrementUsedCount -type InviteRepositoryMockIncrementUsedCountParams struct { - ctx context.Context - code int64 -} - -// InviteRepositoryMockIncrementUsedCountParamPtrs contains pointers to parameters of the InviteRepository.IncrementUsedCount -type InviteRepositoryMockIncrementUsedCountParamPtrs struct { - ctx *context.Context - code *int64 -} - -// InviteRepositoryMockIncrementUsedCountResults contains results of the InviteRepository.IncrementUsedCount -type InviteRepositoryMockIncrementUsedCountResults struct { - err error -} - -// InviteRepositoryMockIncrementUsedCountOrigins contains origins of expectations of the InviteRepository.IncrementUsedCount -type InviteRepositoryMockIncrementUsedCountExpectationOrigins struct { - origin string - originCtx string - originCode string -} - -// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning -// the test will fail minimock's automatic final call check if the mocked method was not called at least once. -// Optional() makes method check to work in '0 or more' mode. -// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to -// catch the problems when the expected method call is totally skipped during test run. -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) Optional() *mInviteRepositoryMockIncrementUsedCount { - mmIncrementUsedCount.optional = true - return mmIncrementUsedCount -} - -// Expect sets up expected params for InviteRepository.IncrementUsedCount -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) Expect(ctx context.Context, code int64) *mInviteRepositoryMockIncrementUsedCount { - if mmIncrementUsedCount.mock.funcIncrementUsedCount != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by Set") - } - - if mmIncrementUsedCount.defaultExpectation == nil { - mmIncrementUsedCount.defaultExpectation = &InviteRepositoryMockIncrementUsedCountExpectation{} - } - - if mmIncrementUsedCount.defaultExpectation.paramPtrs != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by ExpectParams functions") - } - - mmIncrementUsedCount.defaultExpectation.params = &InviteRepositoryMockIncrementUsedCountParams{ctx, code} - mmIncrementUsedCount.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) - for _, e := range mmIncrementUsedCount.expectations { - if minimock.Equal(e.params, mmIncrementUsedCount.defaultExpectation.params) { - mmIncrementUsedCount.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmIncrementUsedCount.defaultExpectation.params) - } - } - - return mmIncrementUsedCount -} - -// ExpectCtxParam1 sets up expected param ctx for InviteRepository.IncrementUsedCount -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) ExpectCtxParam1(ctx context.Context) *mInviteRepositoryMockIncrementUsedCount { - if mmIncrementUsedCount.mock.funcIncrementUsedCount != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by Set") - } - - if mmIncrementUsedCount.defaultExpectation == nil { - mmIncrementUsedCount.defaultExpectation = &InviteRepositoryMockIncrementUsedCountExpectation{} - } - - if mmIncrementUsedCount.defaultExpectation.params != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by Expect") - } - - if mmIncrementUsedCount.defaultExpectation.paramPtrs == nil { - mmIncrementUsedCount.defaultExpectation.paramPtrs = &InviteRepositoryMockIncrementUsedCountParamPtrs{} - } - mmIncrementUsedCount.defaultExpectation.paramPtrs.ctx = &ctx - mmIncrementUsedCount.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) - - return mmIncrementUsedCount -} - -// ExpectCodeParam2 sets up expected param code for InviteRepository.IncrementUsedCount -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) ExpectCodeParam2(code int64) *mInviteRepositoryMockIncrementUsedCount { - if mmIncrementUsedCount.mock.funcIncrementUsedCount != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by Set") - } - - if mmIncrementUsedCount.defaultExpectation == nil { - mmIncrementUsedCount.defaultExpectation = &InviteRepositoryMockIncrementUsedCountExpectation{} - } - - if mmIncrementUsedCount.defaultExpectation.params != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by Expect") - } - - if mmIncrementUsedCount.defaultExpectation.paramPtrs == nil { - mmIncrementUsedCount.defaultExpectation.paramPtrs = &InviteRepositoryMockIncrementUsedCountParamPtrs{} - } - mmIncrementUsedCount.defaultExpectation.paramPtrs.code = &code - mmIncrementUsedCount.defaultExpectation.expectationOrigins.originCode = minimock.CallerInfo(1) - - return mmIncrementUsedCount -} - -// Inspect accepts an inspector function that has same arguments as the InviteRepository.IncrementUsedCount -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) Inspect(f func(ctx context.Context, code int64)) *mInviteRepositoryMockIncrementUsedCount { - if mmIncrementUsedCount.mock.inspectFuncIncrementUsedCount != nil { - mmIncrementUsedCount.mock.t.Fatalf("Inspect function is already set for InviteRepositoryMock.IncrementUsedCount") - } - - mmIncrementUsedCount.mock.inspectFuncIncrementUsedCount = f - - return mmIncrementUsedCount -} - -// Return sets up results that will be returned by InviteRepository.IncrementUsedCount -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) Return(err error) *InviteRepositoryMock { - if mmIncrementUsedCount.mock.funcIncrementUsedCount != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by Set") - } - - if mmIncrementUsedCount.defaultExpectation == nil { - mmIncrementUsedCount.defaultExpectation = &InviteRepositoryMockIncrementUsedCountExpectation{mock: mmIncrementUsedCount.mock} - } - mmIncrementUsedCount.defaultExpectation.results = &InviteRepositoryMockIncrementUsedCountResults{err} - mmIncrementUsedCount.defaultExpectation.returnOrigin = minimock.CallerInfo(1) - return mmIncrementUsedCount.mock -} - -// Set uses given function f to mock the InviteRepository.IncrementUsedCount method -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) Set(f func(ctx context.Context, code int64) (err error)) *InviteRepositoryMock { - if mmIncrementUsedCount.defaultExpectation != nil { - mmIncrementUsedCount.mock.t.Fatalf("Default expectation is already set for the InviteRepository.IncrementUsedCount method") - } - - if len(mmIncrementUsedCount.expectations) > 0 { - mmIncrementUsedCount.mock.t.Fatalf("Some expectations are already set for the InviteRepository.IncrementUsedCount method") - } - - mmIncrementUsedCount.mock.funcIncrementUsedCount = f - mmIncrementUsedCount.mock.funcIncrementUsedCountOrigin = minimock.CallerInfo(1) - return mmIncrementUsedCount.mock -} - -// When sets expectation for the InviteRepository.IncrementUsedCount which will trigger the result defined by the following -// Then helper -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) When(ctx context.Context, code int64) *InviteRepositoryMockIncrementUsedCountExpectation { - if mmIncrementUsedCount.mock.funcIncrementUsedCount != nil { - mmIncrementUsedCount.mock.t.Fatalf("InviteRepositoryMock.IncrementUsedCount mock is already set by Set") - } - - expectation := &InviteRepositoryMockIncrementUsedCountExpectation{ - mock: mmIncrementUsedCount.mock, - params: &InviteRepositoryMockIncrementUsedCountParams{ctx, code}, - expectationOrigins: InviteRepositoryMockIncrementUsedCountExpectationOrigins{origin: minimock.CallerInfo(1)}, - } - mmIncrementUsedCount.expectations = append(mmIncrementUsedCount.expectations, expectation) - return expectation -} - -// Then sets up InviteRepository.IncrementUsedCount return parameters for the expectation previously defined by the When method -func (e *InviteRepositoryMockIncrementUsedCountExpectation) Then(err error) *InviteRepositoryMock { - e.results = &InviteRepositoryMockIncrementUsedCountResults{err} - return e.mock -} - -// Times sets number of times InviteRepository.IncrementUsedCount should be invoked -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) Times(n uint64) *mInviteRepositoryMockIncrementUsedCount { - if n == 0 { - mmIncrementUsedCount.mock.t.Fatalf("Times of InviteRepositoryMock.IncrementUsedCount mock can not be zero") - } - mm_atomic.StoreUint64(&mmIncrementUsedCount.expectedInvocations, n) - mmIncrementUsedCount.expectedInvocationsOrigin = minimock.CallerInfo(1) - return mmIncrementUsedCount -} - -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) invocationsDone() bool { - if len(mmIncrementUsedCount.expectations) == 0 && mmIncrementUsedCount.defaultExpectation == nil && mmIncrementUsedCount.mock.funcIncrementUsedCount == nil { - return true - } - - totalInvocations := mm_atomic.LoadUint64(&mmIncrementUsedCount.mock.afterIncrementUsedCountCounter) - expectedInvocations := mm_atomic.LoadUint64(&mmIncrementUsedCount.expectedInvocations) - - return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) -} - -// IncrementUsedCount implements mm_repository.InviteRepository -func (mmIncrementUsedCount *InviteRepositoryMock) IncrementUsedCount(ctx context.Context, code int64) (err error) { - mm_atomic.AddUint64(&mmIncrementUsedCount.beforeIncrementUsedCountCounter, 1) - defer mm_atomic.AddUint64(&mmIncrementUsedCount.afterIncrementUsedCountCounter, 1) - - mmIncrementUsedCount.t.Helper() - - if mmIncrementUsedCount.inspectFuncIncrementUsedCount != nil { - mmIncrementUsedCount.inspectFuncIncrementUsedCount(ctx, code) - } - - mm_params := InviteRepositoryMockIncrementUsedCountParams{ctx, code} - - // Record call args - mmIncrementUsedCount.IncrementUsedCountMock.mutex.Lock() - mmIncrementUsedCount.IncrementUsedCountMock.callArgs = append(mmIncrementUsedCount.IncrementUsedCountMock.callArgs, &mm_params) - mmIncrementUsedCount.IncrementUsedCountMock.mutex.Unlock() - - for _, e := range mmIncrementUsedCount.IncrementUsedCountMock.expectations { - if minimock.Equal(*e.params, mm_params) { - mm_atomic.AddUint64(&e.Counter, 1) - return e.results.err - } - } - - if mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation != nil { - mm_atomic.AddUint64(&mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation.Counter, 1) - mm_want := mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation.params - mm_want_ptrs := mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation.paramPtrs - - mm_got := InviteRepositoryMockIncrementUsedCountParams{ctx, code} - - if mm_want_ptrs != nil { - - if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { - mmIncrementUsedCount.t.Errorf("InviteRepositoryMock.IncrementUsedCount got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", - mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) - } - - if mm_want_ptrs.code != nil && !minimock.Equal(*mm_want_ptrs.code, mm_got.code) { - mmIncrementUsedCount.t.Errorf("InviteRepositoryMock.IncrementUsedCount got unexpected parameter code, expected at\n%s:\nwant: %#v\n got: %#v%s\n", - mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation.expectationOrigins.originCode, *mm_want_ptrs.code, mm_got.code, minimock.Diff(*mm_want_ptrs.code, mm_got.code)) - } - - } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { - mmIncrementUsedCount.t.Errorf("InviteRepositoryMock.IncrementUsedCount got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", - mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) - } - - mm_results := mmIncrementUsedCount.IncrementUsedCountMock.defaultExpectation.results - if mm_results == nil { - mmIncrementUsedCount.t.Fatal("No results are set for the InviteRepositoryMock.IncrementUsedCount") - } - return (*mm_results).err - } - if mmIncrementUsedCount.funcIncrementUsedCount != nil { - return mmIncrementUsedCount.funcIncrementUsedCount(ctx, code) - } - mmIncrementUsedCount.t.Fatalf("Unexpected call to InviteRepositoryMock.IncrementUsedCount. %v %v", ctx, code) - return -} - -// IncrementUsedCountAfterCounter returns a count of finished InviteRepositoryMock.IncrementUsedCount invocations -func (mmIncrementUsedCount *InviteRepositoryMock) IncrementUsedCountAfterCounter() uint64 { - return mm_atomic.LoadUint64(&mmIncrementUsedCount.afterIncrementUsedCountCounter) -} - -// IncrementUsedCountBeforeCounter returns a count of InviteRepositoryMock.IncrementUsedCount invocations -func (mmIncrementUsedCount *InviteRepositoryMock) IncrementUsedCountBeforeCounter() uint64 { - return mm_atomic.LoadUint64(&mmIncrementUsedCount.beforeIncrementUsedCountCounter) -} - -// Calls returns a list of arguments used in each call to InviteRepositoryMock.IncrementUsedCount. -// The list is in the same order as the calls were made (i.e. recent calls have a higher index) -func (mmIncrementUsedCount *mInviteRepositoryMockIncrementUsedCount) Calls() []*InviteRepositoryMockIncrementUsedCountParams { - mmIncrementUsedCount.mutex.RLock() - - argCopy := make([]*InviteRepositoryMockIncrementUsedCountParams, len(mmIncrementUsedCount.callArgs)) - copy(argCopy, mmIncrementUsedCount.callArgs) - - mmIncrementUsedCount.mutex.RUnlock() - - return argCopy -} - -// MinimockIncrementUsedCountDone returns true if the count of the IncrementUsedCount invocations corresponds -// the number of defined expectations -func (m *InviteRepositoryMock) MinimockIncrementUsedCountDone() bool { - if m.IncrementUsedCountMock.optional { - // Optional methods provide '0 or more' call count restriction. - return true - } - - for _, e := range m.IncrementUsedCountMock.expectations { - if mm_atomic.LoadUint64(&e.Counter) < 1 { - return false - } - } - - return m.IncrementUsedCountMock.invocationsDone() -} - -// MinimockIncrementUsedCountInspect logs each unmet expectation -func (m *InviteRepositoryMock) MinimockIncrementUsedCountInspect() { - for _, e := range m.IncrementUsedCountMock.expectations { - if mm_atomic.LoadUint64(&e.Counter) < 1 { - m.t.Errorf("Expected call to InviteRepositoryMock.IncrementUsedCount at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) - } - } - - afterIncrementUsedCountCounter := mm_atomic.LoadUint64(&m.afterIncrementUsedCountCounter) - // if default expectation was set then invocations count should be greater than zero - if m.IncrementUsedCountMock.defaultExpectation != nil && afterIncrementUsedCountCounter < 1 { - if m.IncrementUsedCountMock.defaultExpectation.params == nil { - m.t.Errorf("Expected call to InviteRepositoryMock.IncrementUsedCount at\n%s", m.IncrementUsedCountMock.defaultExpectation.returnOrigin) - } else { - m.t.Errorf("Expected call to InviteRepositoryMock.IncrementUsedCount at\n%s with params: %#v", m.IncrementUsedCountMock.defaultExpectation.expectationOrigins.origin, *m.IncrementUsedCountMock.defaultExpectation.params) - } - } - // if func was set then invocations count should be greater than zero - if m.funcIncrementUsedCount != nil && afterIncrementUsedCountCounter < 1 { - m.t.Errorf("Expected call to InviteRepositoryMock.IncrementUsedCount at\n%s", m.funcIncrementUsedCountOrigin) - } - - if !m.IncrementUsedCountMock.invocationsDone() && afterIncrementUsedCountCounter > 0 { - m.t.Errorf("Expected %d calls to InviteRepositoryMock.IncrementUsedCount at\n%s but found %d calls", - mm_atomic.LoadUint64(&m.IncrementUsedCountMock.expectedInvocations), m.IncrementUsedCountMock.expectedInvocationsOrigin, afterIncrementUsedCountCounter) - } -} - // MinimockFinish checks that all mocked methods have been called the expected number of times func (m *InviteRepositoryMock) MinimockFinish() { m.finishOnce.Do(func() { @@ -2899,11 +2900,11 @@ func (m *InviteRepositoryMock) MinimockFinish() { m.MinimockFindActiveByCodeInspect() + m.MinimockFindActiveByUserIDInspect() + m.MinimockFindByCodeInspect() m.MinimockGetUserInvitesInspect() - - m.MinimockIncrementUsedCountInspect() } }) } @@ -2932,7 +2933,7 @@ func (m *InviteRepositoryMock) minimockDone() bool { m.MinimockDeactivateExpiredDone() && m.MinimockDecrementCanBeUsedCountTxDone() && m.MinimockFindActiveByCodeDone() && + m.MinimockFindActiveByUserIDDone() && m.MinimockFindByCodeDone() && - m.MinimockGetUserInvitesDone() && - m.MinimockIncrementUsedCountDone() + m.MinimockGetUserInvitesDone() } diff --git a/internal/mocks/invite_service_mock.go b/internal/mocks/invite_service_mock.go index beee5c7..5de7e05 100644 --- a/internal/mocks/invite_service_mock.go +++ b/internal/mocks/invite_service_mock.go @@ -26,9 +26,9 @@ type InviteServiceMock struct { beforeGenerateCounter uint64 GenerateMock mInviteServiceMockGenerate - funcGetInfo func(ctx context.Context, code int64) (ip1 *model.InviteCode, err error) + funcGetInfo func(ctx context.Context, userID int) (ip1 *model.InviteCode, err error) funcGetInfoOrigin string - inspectFuncGetInfo func(ctx context.Context, code int64) + inspectFuncGetInfo func(ctx context.Context, userID int) afterGetInfoCounter uint64 beforeGetInfoCounter uint64 GetInfoMock mInviteServiceMockGetInfo @@ -484,14 +484,14 @@ type InviteServiceMockGetInfoExpectation struct { // InviteServiceMockGetInfoParams contains parameters of the InviteService.GetInfo type InviteServiceMockGetInfoParams struct { - ctx context.Context - code int64 + ctx context.Context + userID int } // InviteServiceMockGetInfoParamPtrs contains pointers to parameters of the InviteService.GetInfo type InviteServiceMockGetInfoParamPtrs struct { - ctx *context.Context - code *int64 + ctx *context.Context + userID *int } // InviteServiceMockGetInfoResults contains results of the InviteService.GetInfo @@ -502,9 +502,9 @@ type InviteServiceMockGetInfoResults struct { // InviteServiceMockGetInfoOrigins contains origins of expectations of the InviteService.GetInfo type InviteServiceMockGetInfoExpectationOrigins struct { - origin string - originCtx string - originCode string + origin string + originCtx string + originUserID string } // Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning @@ -518,7 +518,7 @@ func (mmGetInfo *mInviteServiceMockGetInfo) Optional() *mInviteServiceMockGetInf } // Expect sets up expected params for InviteService.GetInfo -func (mmGetInfo *mInviteServiceMockGetInfo) Expect(ctx context.Context, code int64) *mInviteServiceMockGetInfo { +func (mmGetInfo *mInviteServiceMockGetInfo) Expect(ctx context.Context, userID int) *mInviteServiceMockGetInfo { if mmGetInfo.mock.funcGetInfo != nil { mmGetInfo.mock.t.Fatalf("InviteServiceMock.GetInfo mock is already set by Set") } @@ -531,7 +531,7 @@ func (mmGetInfo *mInviteServiceMockGetInfo) Expect(ctx context.Context, code int mmGetInfo.mock.t.Fatalf("InviteServiceMock.GetInfo mock is already set by ExpectParams functions") } - mmGetInfo.defaultExpectation.params = &InviteServiceMockGetInfoParams{ctx, code} + mmGetInfo.defaultExpectation.params = &InviteServiceMockGetInfoParams{ctx, userID} mmGetInfo.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) for _, e := range mmGetInfo.expectations { if minimock.Equal(e.params, mmGetInfo.defaultExpectation.params) { @@ -565,8 +565,8 @@ func (mmGetInfo *mInviteServiceMockGetInfo) ExpectCtxParam1(ctx context.Context) return mmGetInfo } -// ExpectCodeParam2 sets up expected param code for InviteService.GetInfo -func (mmGetInfo *mInviteServiceMockGetInfo) ExpectCodeParam2(code int64) *mInviteServiceMockGetInfo { +// ExpectUserIDParam2 sets up expected param userID for InviteService.GetInfo +func (mmGetInfo *mInviteServiceMockGetInfo) ExpectUserIDParam2(userID int) *mInviteServiceMockGetInfo { if mmGetInfo.mock.funcGetInfo != nil { mmGetInfo.mock.t.Fatalf("InviteServiceMock.GetInfo mock is already set by Set") } @@ -582,14 +582,14 @@ func (mmGetInfo *mInviteServiceMockGetInfo) ExpectCodeParam2(code int64) *mInvit if mmGetInfo.defaultExpectation.paramPtrs == nil { mmGetInfo.defaultExpectation.paramPtrs = &InviteServiceMockGetInfoParamPtrs{} } - mmGetInfo.defaultExpectation.paramPtrs.code = &code - mmGetInfo.defaultExpectation.expectationOrigins.originCode = minimock.CallerInfo(1) + mmGetInfo.defaultExpectation.paramPtrs.userID = &userID + mmGetInfo.defaultExpectation.expectationOrigins.originUserID = minimock.CallerInfo(1) return mmGetInfo } // Inspect accepts an inspector function that has same arguments as the InviteService.GetInfo -func (mmGetInfo *mInviteServiceMockGetInfo) Inspect(f func(ctx context.Context, code int64)) *mInviteServiceMockGetInfo { +func (mmGetInfo *mInviteServiceMockGetInfo) Inspect(f func(ctx context.Context, userID int)) *mInviteServiceMockGetInfo { if mmGetInfo.mock.inspectFuncGetInfo != nil { mmGetInfo.mock.t.Fatalf("Inspect function is already set for InviteServiceMock.GetInfo") } @@ -614,7 +614,7 @@ func (mmGetInfo *mInviteServiceMockGetInfo) Return(ip1 *model.InviteCode, err er } // Set uses given function f to mock the InviteService.GetInfo method -func (mmGetInfo *mInviteServiceMockGetInfo) Set(f func(ctx context.Context, code int64) (ip1 *model.InviteCode, err error)) *InviteServiceMock { +func (mmGetInfo *mInviteServiceMockGetInfo) Set(f func(ctx context.Context, userID int) (ip1 *model.InviteCode, err error)) *InviteServiceMock { if mmGetInfo.defaultExpectation != nil { mmGetInfo.mock.t.Fatalf("Default expectation is already set for the InviteService.GetInfo method") } @@ -630,14 +630,14 @@ func (mmGetInfo *mInviteServiceMockGetInfo) Set(f func(ctx context.Context, code // When sets expectation for the InviteService.GetInfo which will trigger the result defined by the following // Then helper -func (mmGetInfo *mInviteServiceMockGetInfo) When(ctx context.Context, code int64) *InviteServiceMockGetInfoExpectation { +func (mmGetInfo *mInviteServiceMockGetInfo) When(ctx context.Context, userID int) *InviteServiceMockGetInfoExpectation { if mmGetInfo.mock.funcGetInfo != nil { mmGetInfo.mock.t.Fatalf("InviteServiceMock.GetInfo mock is already set by Set") } expectation := &InviteServiceMockGetInfoExpectation{ mock: mmGetInfo.mock, - params: &InviteServiceMockGetInfoParams{ctx, code}, + params: &InviteServiceMockGetInfoParams{ctx, userID}, expectationOrigins: InviteServiceMockGetInfoExpectationOrigins{origin: minimock.CallerInfo(1)}, } mmGetInfo.expectations = append(mmGetInfo.expectations, expectation) @@ -672,17 +672,17 @@ func (mmGetInfo *mInviteServiceMockGetInfo) invocationsDone() bool { } // GetInfo implements mm_service.InviteService -func (mmGetInfo *InviteServiceMock) GetInfo(ctx context.Context, code int64) (ip1 *model.InviteCode, err error) { +func (mmGetInfo *InviteServiceMock) GetInfo(ctx context.Context, userID int) (ip1 *model.InviteCode, err error) { mm_atomic.AddUint64(&mmGetInfo.beforeGetInfoCounter, 1) defer mm_atomic.AddUint64(&mmGetInfo.afterGetInfoCounter, 1) mmGetInfo.t.Helper() if mmGetInfo.inspectFuncGetInfo != nil { - mmGetInfo.inspectFuncGetInfo(ctx, code) + mmGetInfo.inspectFuncGetInfo(ctx, userID) } - mm_params := InviteServiceMockGetInfoParams{ctx, code} + mm_params := InviteServiceMockGetInfoParams{ctx, userID} // Record call args mmGetInfo.GetInfoMock.mutex.Lock() @@ -701,7 +701,7 @@ func (mmGetInfo *InviteServiceMock) GetInfo(ctx context.Context, code int64) (ip mm_want := mmGetInfo.GetInfoMock.defaultExpectation.params mm_want_ptrs := mmGetInfo.GetInfoMock.defaultExpectation.paramPtrs - mm_got := InviteServiceMockGetInfoParams{ctx, code} + mm_got := InviteServiceMockGetInfoParams{ctx, userID} if mm_want_ptrs != nil { @@ -710,9 +710,9 @@ func (mmGetInfo *InviteServiceMock) GetInfo(ctx context.Context, code int64) (ip mmGetInfo.GetInfoMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) } - if mm_want_ptrs.code != nil && !minimock.Equal(*mm_want_ptrs.code, mm_got.code) { - mmGetInfo.t.Errorf("InviteServiceMock.GetInfo got unexpected parameter code, expected at\n%s:\nwant: %#v\n got: %#v%s\n", - mmGetInfo.GetInfoMock.defaultExpectation.expectationOrigins.originCode, *mm_want_ptrs.code, mm_got.code, minimock.Diff(*mm_want_ptrs.code, mm_got.code)) + if mm_want_ptrs.userID != nil && !minimock.Equal(*mm_want_ptrs.userID, mm_got.userID) { + mmGetInfo.t.Errorf("InviteServiceMock.GetInfo got unexpected parameter userID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetInfo.GetInfoMock.defaultExpectation.expectationOrigins.originUserID, *mm_want_ptrs.userID, mm_got.userID, minimock.Diff(*mm_want_ptrs.userID, mm_got.userID)) } } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { @@ -727,9 +727,9 @@ func (mmGetInfo *InviteServiceMock) GetInfo(ctx context.Context, code int64) (ip return (*mm_results).ip1, (*mm_results).err } if mmGetInfo.funcGetInfo != nil { - return mmGetInfo.funcGetInfo(ctx, code) + return mmGetInfo.funcGetInfo(ctx, userID) } - mmGetInfo.t.Fatalf("Unexpected call to InviteServiceMock.GetInfo. %v %v", ctx, code) + mmGetInfo.t.Fatalf("Unexpected call to InviteServiceMock.GetInfo. %v %v", ctx, userID) return } diff --git a/internal/mocks/request_repository_mock.go b/internal/mocks/request_repository_mock.go index 1bef473..76aad7f 100644 --- a/internal/mocks/request_repository_mock.go +++ b/internal/mocks/request_repository_mock.go @@ -21,6 +21,13 @@ type RequestRepositoryMock struct { t minimock.Tester finishOnce sync.Once + funcCheckOwnership func(ctx context.Context, requestID uuid.UUID, userID int) (b1 bool, err error) + funcCheckOwnershipOrigin string + inspectFuncCheckOwnership func(ctx context.Context, requestID uuid.UUID, userID int) + afterCheckOwnershipCounter uint64 + beforeCheckOwnershipCounter uint64 + CheckOwnershipMock mRequestRepositoryMockCheckOwnership + funcCreate func(ctx context.Context, req *model.Request) (err error) funcCreateOrigin string inspectFuncCreate func(ctx context.Context, req *model.Request) @@ -86,6 +93,9 @@ func NewRequestRepositoryMock(t minimock.Tester) *RequestRepositoryMock { controller.RegisterMocker(m) } + m.CheckOwnershipMock = mRequestRepositoryMockCheckOwnership{mock: m} + m.CheckOwnershipMock.callArgs = []*RequestRepositoryMockCheckOwnershipParams{} + m.CreateMock = mRequestRepositoryMockCreate{mock: m} m.CreateMock.callArgs = []*RequestRepositoryMockCreateParams{} @@ -115,6 +125,380 @@ func NewRequestRepositoryMock(t minimock.Tester) *RequestRepositoryMock { return m } +type mRequestRepositoryMockCheckOwnership struct { + optional bool + mock *RequestRepositoryMock + defaultExpectation *RequestRepositoryMockCheckOwnershipExpectation + expectations []*RequestRepositoryMockCheckOwnershipExpectation + + callArgs []*RequestRepositoryMockCheckOwnershipParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// RequestRepositoryMockCheckOwnershipExpectation specifies expectation struct of the RequestRepository.CheckOwnership +type RequestRepositoryMockCheckOwnershipExpectation struct { + mock *RequestRepositoryMock + params *RequestRepositoryMockCheckOwnershipParams + paramPtrs *RequestRepositoryMockCheckOwnershipParamPtrs + expectationOrigins RequestRepositoryMockCheckOwnershipExpectationOrigins + results *RequestRepositoryMockCheckOwnershipResults + returnOrigin string + Counter uint64 +} + +// RequestRepositoryMockCheckOwnershipParams contains parameters of the RequestRepository.CheckOwnership +type RequestRepositoryMockCheckOwnershipParams struct { + ctx context.Context + requestID uuid.UUID + userID int +} + +// RequestRepositoryMockCheckOwnershipParamPtrs contains pointers to parameters of the RequestRepository.CheckOwnership +type RequestRepositoryMockCheckOwnershipParamPtrs struct { + ctx *context.Context + requestID *uuid.UUID + userID *int +} + +// RequestRepositoryMockCheckOwnershipResults contains results of the RequestRepository.CheckOwnership +type RequestRepositoryMockCheckOwnershipResults struct { + b1 bool + err error +} + +// RequestRepositoryMockCheckOwnershipOrigins contains origins of expectations of the RequestRepository.CheckOwnership +type RequestRepositoryMockCheckOwnershipExpectationOrigins struct { + origin string + originCtx string + originRequestID string + originUserID string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) Optional() *mRequestRepositoryMockCheckOwnership { + mmCheckOwnership.optional = true + return mmCheckOwnership +} + +// Expect sets up expected params for RequestRepository.CheckOwnership +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) Expect(ctx context.Context, requestID uuid.UUID, userID int) *mRequestRepositoryMockCheckOwnership { + if mmCheckOwnership.mock.funcCheckOwnership != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Set") + } + + if mmCheckOwnership.defaultExpectation == nil { + mmCheckOwnership.defaultExpectation = &RequestRepositoryMockCheckOwnershipExpectation{} + } + + if mmCheckOwnership.defaultExpectation.paramPtrs != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by ExpectParams functions") + } + + mmCheckOwnership.defaultExpectation.params = &RequestRepositoryMockCheckOwnershipParams{ctx, requestID, userID} + mmCheckOwnership.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmCheckOwnership.expectations { + if minimock.Equal(e.params, mmCheckOwnership.defaultExpectation.params) { + mmCheckOwnership.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmCheckOwnership.defaultExpectation.params) + } + } + + return mmCheckOwnership +} + +// ExpectCtxParam1 sets up expected param ctx for RequestRepository.CheckOwnership +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) ExpectCtxParam1(ctx context.Context) *mRequestRepositoryMockCheckOwnership { + if mmCheckOwnership.mock.funcCheckOwnership != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Set") + } + + if mmCheckOwnership.defaultExpectation == nil { + mmCheckOwnership.defaultExpectation = &RequestRepositoryMockCheckOwnershipExpectation{} + } + + if mmCheckOwnership.defaultExpectation.params != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Expect") + } + + if mmCheckOwnership.defaultExpectation.paramPtrs == nil { + mmCheckOwnership.defaultExpectation.paramPtrs = &RequestRepositoryMockCheckOwnershipParamPtrs{} + } + mmCheckOwnership.defaultExpectation.paramPtrs.ctx = &ctx + mmCheckOwnership.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmCheckOwnership +} + +// ExpectRequestIDParam2 sets up expected param requestID for RequestRepository.CheckOwnership +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) ExpectRequestIDParam2(requestID uuid.UUID) *mRequestRepositoryMockCheckOwnership { + if mmCheckOwnership.mock.funcCheckOwnership != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Set") + } + + if mmCheckOwnership.defaultExpectation == nil { + mmCheckOwnership.defaultExpectation = &RequestRepositoryMockCheckOwnershipExpectation{} + } + + if mmCheckOwnership.defaultExpectation.params != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Expect") + } + + if mmCheckOwnership.defaultExpectation.paramPtrs == nil { + mmCheckOwnership.defaultExpectation.paramPtrs = &RequestRepositoryMockCheckOwnershipParamPtrs{} + } + mmCheckOwnership.defaultExpectation.paramPtrs.requestID = &requestID + mmCheckOwnership.defaultExpectation.expectationOrigins.originRequestID = minimock.CallerInfo(1) + + return mmCheckOwnership +} + +// ExpectUserIDParam3 sets up expected param userID for RequestRepository.CheckOwnership +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) ExpectUserIDParam3(userID int) *mRequestRepositoryMockCheckOwnership { + if mmCheckOwnership.mock.funcCheckOwnership != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Set") + } + + if mmCheckOwnership.defaultExpectation == nil { + mmCheckOwnership.defaultExpectation = &RequestRepositoryMockCheckOwnershipExpectation{} + } + + if mmCheckOwnership.defaultExpectation.params != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Expect") + } + + if mmCheckOwnership.defaultExpectation.paramPtrs == nil { + mmCheckOwnership.defaultExpectation.paramPtrs = &RequestRepositoryMockCheckOwnershipParamPtrs{} + } + mmCheckOwnership.defaultExpectation.paramPtrs.userID = &userID + mmCheckOwnership.defaultExpectation.expectationOrigins.originUserID = minimock.CallerInfo(1) + + return mmCheckOwnership +} + +// Inspect accepts an inspector function that has same arguments as the RequestRepository.CheckOwnership +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) Inspect(f func(ctx context.Context, requestID uuid.UUID, userID int)) *mRequestRepositoryMockCheckOwnership { + if mmCheckOwnership.mock.inspectFuncCheckOwnership != nil { + mmCheckOwnership.mock.t.Fatalf("Inspect function is already set for RequestRepositoryMock.CheckOwnership") + } + + mmCheckOwnership.mock.inspectFuncCheckOwnership = f + + return mmCheckOwnership +} + +// Return sets up results that will be returned by RequestRepository.CheckOwnership +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) Return(b1 bool, err error) *RequestRepositoryMock { + if mmCheckOwnership.mock.funcCheckOwnership != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Set") + } + + if mmCheckOwnership.defaultExpectation == nil { + mmCheckOwnership.defaultExpectation = &RequestRepositoryMockCheckOwnershipExpectation{mock: mmCheckOwnership.mock} + } + mmCheckOwnership.defaultExpectation.results = &RequestRepositoryMockCheckOwnershipResults{b1, err} + mmCheckOwnership.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmCheckOwnership.mock +} + +// Set uses given function f to mock the RequestRepository.CheckOwnership method +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) Set(f func(ctx context.Context, requestID uuid.UUID, userID int) (b1 bool, err error)) *RequestRepositoryMock { + if mmCheckOwnership.defaultExpectation != nil { + mmCheckOwnership.mock.t.Fatalf("Default expectation is already set for the RequestRepository.CheckOwnership method") + } + + if len(mmCheckOwnership.expectations) > 0 { + mmCheckOwnership.mock.t.Fatalf("Some expectations are already set for the RequestRepository.CheckOwnership method") + } + + mmCheckOwnership.mock.funcCheckOwnership = f + mmCheckOwnership.mock.funcCheckOwnershipOrigin = minimock.CallerInfo(1) + return mmCheckOwnership.mock +} + +// When sets expectation for the RequestRepository.CheckOwnership which will trigger the result defined by the following +// Then helper +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) When(ctx context.Context, requestID uuid.UUID, userID int) *RequestRepositoryMockCheckOwnershipExpectation { + if mmCheckOwnership.mock.funcCheckOwnership != nil { + mmCheckOwnership.mock.t.Fatalf("RequestRepositoryMock.CheckOwnership mock is already set by Set") + } + + expectation := &RequestRepositoryMockCheckOwnershipExpectation{ + mock: mmCheckOwnership.mock, + params: &RequestRepositoryMockCheckOwnershipParams{ctx, requestID, userID}, + expectationOrigins: RequestRepositoryMockCheckOwnershipExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmCheckOwnership.expectations = append(mmCheckOwnership.expectations, expectation) + return expectation +} + +// Then sets up RequestRepository.CheckOwnership return parameters for the expectation previously defined by the When method +func (e *RequestRepositoryMockCheckOwnershipExpectation) Then(b1 bool, err error) *RequestRepositoryMock { + e.results = &RequestRepositoryMockCheckOwnershipResults{b1, err} + return e.mock +} + +// Times sets number of times RequestRepository.CheckOwnership should be invoked +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) Times(n uint64) *mRequestRepositoryMockCheckOwnership { + if n == 0 { + mmCheckOwnership.mock.t.Fatalf("Times of RequestRepositoryMock.CheckOwnership mock can not be zero") + } + mm_atomic.StoreUint64(&mmCheckOwnership.expectedInvocations, n) + mmCheckOwnership.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmCheckOwnership +} + +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) invocationsDone() bool { + if len(mmCheckOwnership.expectations) == 0 && mmCheckOwnership.defaultExpectation == nil && mmCheckOwnership.mock.funcCheckOwnership == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmCheckOwnership.mock.afterCheckOwnershipCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmCheckOwnership.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// CheckOwnership implements mm_repository.RequestRepository +func (mmCheckOwnership *RequestRepositoryMock) CheckOwnership(ctx context.Context, requestID uuid.UUID, userID int) (b1 bool, err error) { + mm_atomic.AddUint64(&mmCheckOwnership.beforeCheckOwnershipCounter, 1) + defer mm_atomic.AddUint64(&mmCheckOwnership.afterCheckOwnershipCounter, 1) + + mmCheckOwnership.t.Helper() + + if mmCheckOwnership.inspectFuncCheckOwnership != nil { + mmCheckOwnership.inspectFuncCheckOwnership(ctx, requestID, userID) + } + + mm_params := RequestRepositoryMockCheckOwnershipParams{ctx, requestID, userID} + + // Record call args + mmCheckOwnership.CheckOwnershipMock.mutex.Lock() + mmCheckOwnership.CheckOwnershipMock.callArgs = append(mmCheckOwnership.CheckOwnershipMock.callArgs, &mm_params) + mmCheckOwnership.CheckOwnershipMock.mutex.Unlock() + + for _, e := range mmCheckOwnership.CheckOwnershipMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.b1, e.results.err + } + } + + if mmCheckOwnership.CheckOwnershipMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmCheckOwnership.CheckOwnershipMock.defaultExpectation.Counter, 1) + mm_want := mmCheckOwnership.CheckOwnershipMock.defaultExpectation.params + mm_want_ptrs := mmCheckOwnership.CheckOwnershipMock.defaultExpectation.paramPtrs + + mm_got := RequestRepositoryMockCheckOwnershipParams{ctx, requestID, userID} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmCheckOwnership.t.Errorf("RequestRepositoryMock.CheckOwnership got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCheckOwnership.CheckOwnershipMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.requestID != nil && !minimock.Equal(*mm_want_ptrs.requestID, mm_got.requestID) { + mmCheckOwnership.t.Errorf("RequestRepositoryMock.CheckOwnership got unexpected parameter requestID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCheckOwnership.CheckOwnershipMock.defaultExpectation.expectationOrigins.originRequestID, *mm_want_ptrs.requestID, mm_got.requestID, minimock.Diff(*mm_want_ptrs.requestID, mm_got.requestID)) + } + + if mm_want_ptrs.userID != nil && !minimock.Equal(*mm_want_ptrs.userID, mm_got.userID) { + mmCheckOwnership.t.Errorf("RequestRepositoryMock.CheckOwnership got unexpected parameter userID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCheckOwnership.CheckOwnershipMock.defaultExpectation.expectationOrigins.originUserID, *mm_want_ptrs.userID, mm_got.userID, minimock.Diff(*mm_want_ptrs.userID, mm_got.userID)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmCheckOwnership.t.Errorf("RequestRepositoryMock.CheckOwnership got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCheckOwnership.CheckOwnershipMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmCheckOwnership.CheckOwnershipMock.defaultExpectation.results + if mm_results == nil { + mmCheckOwnership.t.Fatal("No results are set for the RequestRepositoryMock.CheckOwnership") + } + return (*mm_results).b1, (*mm_results).err + } + if mmCheckOwnership.funcCheckOwnership != nil { + return mmCheckOwnership.funcCheckOwnership(ctx, requestID, userID) + } + mmCheckOwnership.t.Fatalf("Unexpected call to RequestRepositoryMock.CheckOwnership. %v %v %v", ctx, requestID, userID) + return +} + +// CheckOwnershipAfterCounter returns a count of finished RequestRepositoryMock.CheckOwnership invocations +func (mmCheckOwnership *RequestRepositoryMock) CheckOwnershipAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmCheckOwnership.afterCheckOwnershipCounter) +} + +// CheckOwnershipBeforeCounter returns a count of RequestRepositoryMock.CheckOwnership invocations +func (mmCheckOwnership *RequestRepositoryMock) CheckOwnershipBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmCheckOwnership.beforeCheckOwnershipCounter) +} + +// Calls returns a list of arguments used in each call to RequestRepositoryMock.CheckOwnership. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmCheckOwnership *mRequestRepositoryMockCheckOwnership) Calls() []*RequestRepositoryMockCheckOwnershipParams { + mmCheckOwnership.mutex.RLock() + + argCopy := make([]*RequestRepositoryMockCheckOwnershipParams, len(mmCheckOwnership.callArgs)) + copy(argCopy, mmCheckOwnership.callArgs) + + mmCheckOwnership.mutex.RUnlock() + + return argCopy +} + +// MinimockCheckOwnershipDone returns true if the count of the CheckOwnership invocations corresponds +// the number of defined expectations +func (m *RequestRepositoryMock) MinimockCheckOwnershipDone() bool { + if m.CheckOwnershipMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.CheckOwnershipMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.CheckOwnershipMock.invocationsDone() +} + +// MinimockCheckOwnershipInspect logs each unmet expectation +func (m *RequestRepositoryMock) MinimockCheckOwnershipInspect() { + for _, e := range m.CheckOwnershipMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to RequestRepositoryMock.CheckOwnership at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterCheckOwnershipCounter := mm_atomic.LoadUint64(&m.afterCheckOwnershipCounter) + // if default expectation was set then invocations count should be greater than zero + if m.CheckOwnershipMock.defaultExpectation != nil && afterCheckOwnershipCounter < 1 { + if m.CheckOwnershipMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to RequestRepositoryMock.CheckOwnership at\n%s", m.CheckOwnershipMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to RequestRepositoryMock.CheckOwnership at\n%s with params: %#v", m.CheckOwnershipMock.defaultExpectation.expectationOrigins.origin, *m.CheckOwnershipMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcCheckOwnership != nil && afterCheckOwnershipCounter < 1 { + m.t.Errorf("Expected call to RequestRepositoryMock.CheckOwnership at\n%s", m.funcCheckOwnershipOrigin) + } + + if !m.CheckOwnershipMock.invocationsDone() && afterCheckOwnershipCounter > 0 { + m.t.Errorf("Expected %d calls to RequestRepositoryMock.CheckOwnership at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.CheckOwnershipMock.expectedInvocations), m.CheckOwnershipMock.expectedInvocationsOrigin, afterCheckOwnershipCounter) + } +} + type mRequestRepositoryMockCreate struct { optional bool mock *RequestRepositoryMock @@ -3047,6 +3431,8 @@ func (m *RequestRepositoryMock) MinimockUpdateWithTZTxInspect() { func (m *RequestRepositoryMock) MinimockFinish() { m.finishOnce.Do(func() { if !m.minimockDone() { + m.MinimockCheckOwnershipInspect() + m.MinimockCreateInspect() m.MinimockGetByIDInspect() @@ -3085,6 +3471,7 @@ func (m *RequestRepositoryMock) MinimockWait(timeout mm_time.Duration) { func (m *RequestRepositoryMock) minimockDone() bool { done := true return done && + m.MinimockCheckOwnershipDone() && m.MinimockCreateDone() && m.MinimockGetByIDDone() && m.MinimockGetByUserIDDone() && diff --git a/internal/mocks/request_service_mock.go b/internal/mocks/request_service_mock.go index 12ad5df..4414708 100644 --- a/internal/mocks/request_service_mock.go +++ b/internal/mocks/request_service_mock.go @@ -27,9 +27,9 @@ type RequestServiceMock struct { beforeApproveTZCounter uint64 ApproveTZMock mRequestServiceMockApproveTZ - funcCreateTZ func(ctx context.Context, userID int, requestTxt string) (u1 uuid.UUID, s1 string, err error) + funcCreateTZ func(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (u1 uuid.UUID, s1 string, err error) funcCreateTZOrigin string - inspectFuncCreateTZ func(ctx context.Context, userID int, requestTxt string) + inspectFuncCreateTZ func(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) afterCreateTZCounter uint64 beforeCreateTZCounter uint64 CreateTZMock mRequestServiceMockCreateTZ @@ -41,9 +41,9 @@ type RequestServiceMock struct { beforeGetMailingListCounter uint64 GetMailingListMock mRequestServiceMockGetMailingList - funcGetMailingListByID func(ctx context.Context, requestID uuid.UUID) (rp1 *model.RequestDetail, err error) + funcGetMailingListByID func(ctx context.Context, requestID uuid.UUID, userID int) (rp1 *model.RequestDetail, err error) funcGetMailingListByIDOrigin string - inspectFuncGetMailingListByID func(ctx context.Context, requestID uuid.UUID) + inspectFuncGetMailingListByID func(ctx context.Context, requestID uuid.UUID, userID int) afterGetMailingListByIDCounter uint64 beforeGetMailingListByIDCounter uint64 GetMailingListByIDMock mRequestServiceMockGetMailingListByID @@ -508,6 +508,8 @@ type RequestServiceMockCreateTZParams struct { ctx context.Context userID int requestTxt string + fileData []byte + fileName string } // RequestServiceMockCreateTZParamPtrs contains pointers to parameters of the RequestService.CreateTZ @@ -515,6 +517,8 @@ type RequestServiceMockCreateTZParamPtrs struct { ctx *context.Context userID *int requestTxt *string + fileData *[]byte + fileName *string } // RequestServiceMockCreateTZResults contains results of the RequestService.CreateTZ @@ -530,6 +534,8 @@ type RequestServiceMockCreateTZExpectationOrigins struct { originCtx string originUserID string originRequestTxt string + originFileData string + originFileName string } // Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning @@ -543,7 +549,7 @@ func (mmCreateTZ *mRequestServiceMockCreateTZ) Optional() *mRequestServiceMockCr } // Expect sets up expected params for RequestService.CreateTZ -func (mmCreateTZ *mRequestServiceMockCreateTZ) Expect(ctx context.Context, userID int, requestTxt string) *mRequestServiceMockCreateTZ { +func (mmCreateTZ *mRequestServiceMockCreateTZ) Expect(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) *mRequestServiceMockCreateTZ { if mmCreateTZ.mock.funcCreateTZ != nil { mmCreateTZ.mock.t.Fatalf("RequestServiceMock.CreateTZ mock is already set by Set") } @@ -556,7 +562,7 @@ func (mmCreateTZ *mRequestServiceMockCreateTZ) Expect(ctx context.Context, userI mmCreateTZ.mock.t.Fatalf("RequestServiceMock.CreateTZ mock is already set by ExpectParams functions") } - mmCreateTZ.defaultExpectation.params = &RequestServiceMockCreateTZParams{ctx, userID, requestTxt} + mmCreateTZ.defaultExpectation.params = &RequestServiceMockCreateTZParams{ctx, userID, requestTxt, fileData, fileName} mmCreateTZ.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) for _, e := range mmCreateTZ.expectations { if minimock.Equal(e.params, mmCreateTZ.defaultExpectation.params) { @@ -636,8 +642,54 @@ func (mmCreateTZ *mRequestServiceMockCreateTZ) ExpectRequestTxtParam3(requestTxt return mmCreateTZ } +// ExpectFileDataParam4 sets up expected param fileData for RequestService.CreateTZ +func (mmCreateTZ *mRequestServiceMockCreateTZ) ExpectFileDataParam4(fileData []byte) *mRequestServiceMockCreateTZ { + if mmCreateTZ.mock.funcCreateTZ != nil { + mmCreateTZ.mock.t.Fatalf("RequestServiceMock.CreateTZ mock is already set by Set") + } + + if mmCreateTZ.defaultExpectation == nil { + mmCreateTZ.defaultExpectation = &RequestServiceMockCreateTZExpectation{} + } + + if mmCreateTZ.defaultExpectation.params != nil { + mmCreateTZ.mock.t.Fatalf("RequestServiceMock.CreateTZ mock is already set by Expect") + } + + if mmCreateTZ.defaultExpectation.paramPtrs == nil { + mmCreateTZ.defaultExpectation.paramPtrs = &RequestServiceMockCreateTZParamPtrs{} + } + mmCreateTZ.defaultExpectation.paramPtrs.fileData = &fileData + mmCreateTZ.defaultExpectation.expectationOrigins.originFileData = minimock.CallerInfo(1) + + return mmCreateTZ +} + +// ExpectFileNameParam5 sets up expected param fileName for RequestService.CreateTZ +func (mmCreateTZ *mRequestServiceMockCreateTZ) ExpectFileNameParam5(fileName string) *mRequestServiceMockCreateTZ { + if mmCreateTZ.mock.funcCreateTZ != nil { + mmCreateTZ.mock.t.Fatalf("RequestServiceMock.CreateTZ mock is already set by Set") + } + + if mmCreateTZ.defaultExpectation == nil { + mmCreateTZ.defaultExpectation = &RequestServiceMockCreateTZExpectation{} + } + + if mmCreateTZ.defaultExpectation.params != nil { + mmCreateTZ.mock.t.Fatalf("RequestServiceMock.CreateTZ mock is already set by Expect") + } + + if mmCreateTZ.defaultExpectation.paramPtrs == nil { + mmCreateTZ.defaultExpectation.paramPtrs = &RequestServiceMockCreateTZParamPtrs{} + } + mmCreateTZ.defaultExpectation.paramPtrs.fileName = &fileName + mmCreateTZ.defaultExpectation.expectationOrigins.originFileName = minimock.CallerInfo(1) + + return mmCreateTZ +} + // Inspect accepts an inspector function that has same arguments as the RequestService.CreateTZ -func (mmCreateTZ *mRequestServiceMockCreateTZ) Inspect(f func(ctx context.Context, userID int, requestTxt string)) *mRequestServiceMockCreateTZ { +func (mmCreateTZ *mRequestServiceMockCreateTZ) Inspect(f func(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string)) *mRequestServiceMockCreateTZ { if mmCreateTZ.mock.inspectFuncCreateTZ != nil { mmCreateTZ.mock.t.Fatalf("Inspect function is already set for RequestServiceMock.CreateTZ") } @@ -662,7 +714,7 @@ func (mmCreateTZ *mRequestServiceMockCreateTZ) Return(u1 uuid.UUID, s1 string, e } // Set uses given function f to mock the RequestService.CreateTZ method -func (mmCreateTZ *mRequestServiceMockCreateTZ) Set(f func(ctx context.Context, userID int, requestTxt string) (u1 uuid.UUID, s1 string, err error)) *RequestServiceMock { +func (mmCreateTZ *mRequestServiceMockCreateTZ) Set(f func(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (u1 uuid.UUID, s1 string, err error)) *RequestServiceMock { if mmCreateTZ.defaultExpectation != nil { mmCreateTZ.mock.t.Fatalf("Default expectation is already set for the RequestService.CreateTZ method") } @@ -678,14 +730,14 @@ func (mmCreateTZ *mRequestServiceMockCreateTZ) Set(f func(ctx context.Context, u // When sets expectation for the RequestService.CreateTZ which will trigger the result defined by the following // Then helper -func (mmCreateTZ *mRequestServiceMockCreateTZ) When(ctx context.Context, userID int, requestTxt string) *RequestServiceMockCreateTZExpectation { +func (mmCreateTZ *mRequestServiceMockCreateTZ) When(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) *RequestServiceMockCreateTZExpectation { if mmCreateTZ.mock.funcCreateTZ != nil { mmCreateTZ.mock.t.Fatalf("RequestServiceMock.CreateTZ mock is already set by Set") } expectation := &RequestServiceMockCreateTZExpectation{ mock: mmCreateTZ.mock, - params: &RequestServiceMockCreateTZParams{ctx, userID, requestTxt}, + params: &RequestServiceMockCreateTZParams{ctx, userID, requestTxt, fileData, fileName}, expectationOrigins: RequestServiceMockCreateTZExpectationOrigins{origin: minimock.CallerInfo(1)}, } mmCreateTZ.expectations = append(mmCreateTZ.expectations, expectation) @@ -720,17 +772,17 @@ func (mmCreateTZ *mRequestServiceMockCreateTZ) invocationsDone() bool { } // CreateTZ implements mm_service.RequestService -func (mmCreateTZ *RequestServiceMock) CreateTZ(ctx context.Context, userID int, requestTxt string) (u1 uuid.UUID, s1 string, err error) { +func (mmCreateTZ *RequestServiceMock) CreateTZ(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (u1 uuid.UUID, s1 string, err error) { mm_atomic.AddUint64(&mmCreateTZ.beforeCreateTZCounter, 1) defer mm_atomic.AddUint64(&mmCreateTZ.afterCreateTZCounter, 1) mmCreateTZ.t.Helper() if mmCreateTZ.inspectFuncCreateTZ != nil { - mmCreateTZ.inspectFuncCreateTZ(ctx, userID, requestTxt) + mmCreateTZ.inspectFuncCreateTZ(ctx, userID, requestTxt, fileData, fileName) } - mm_params := RequestServiceMockCreateTZParams{ctx, userID, requestTxt} + mm_params := RequestServiceMockCreateTZParams{ctx, userID, requestTxt, fileData, fileName} // Record call args mmCreateTZ.CreateTZMock.mutex.Lock() @@ -749,7 +801,7 @@ func (mmCreateTZ *RequestServiceMock) CreateTZ(ctx context.Context, userID int, mm_want := mmCreateTZ.CreateTZMock.defaultExpectation.params mm_want_ptrs := mmCreateTZ.CreateTZMock.defaultExpectation.paramPtrs - mm_got := RequestServiceMockCreateTZParams{ctx, userID, requestTxt} + mm_got := RequestServiceMockCreateTZParams{ctx, userID, requestTxt, fileData, fileName} if mm_want_ptrs != nil { @@ -768,6 +820,16 @@ func (mmCreateTZ *RequestServiceMock) CreateTZ(ctx context.Context, userID int, mmCreateTZ.CreateTZMock.defaultExpectation.expectationOrigins.originRequestTxt, *mm_want_ptrs.requestTxt, mm_got.requestTxt, minimock.Diff(*mm_want_ptrs.requestTxt, mm_got.requestTxt)) } + if mm_want_ptrs.fileData != nil && !minimock.Equal(*mm_want_ptrs.fileData, mm_got.fileData) { + mmCreateTZ.t.Errorf("RequestServiceMock.CreateTZ got unexpected parameter fileData, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCreateTZ.CreateTZMock.defaultExpectation.expectationOrigins.originFileData, *mm_want_ptrs.fileData, mm_got.fileData, minimock.Diff(*mm_want_ptrs.fileData, mm_got.fileData)) + } + + if mm_want_ptrs.fileName != nil && !minimock.Equal(*mm_want_ptrs.fileName, mm_got.fileName) { + mmCreateTZ.t.Errorf("RequestServiceMock.CreateTZ got unexpected parameter fileName, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmCreateTZ.CreateTZMock.defaultExpectation.expectationOrigins.originFileName, *mm_want_ptrs.fileName, mm_got.fileName, minimock.Diff(*mm_want_ptrs.fileName, mm_got.fileName)) + } + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { mmCreateTZ.t.Errorf("RequestServiceMock.CreateTZ got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", mmCreateTZ.CreateTZMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) @@ -780,9 +842,9 @@ func (mmCreateTZ *RequestServiceMock) CreateTZ(ctx context.Context, userID int, return (*mm_results).u1, (*mm_results).s1, (*mm_results).err } if mmCreateTZ.funcCreateTZ != nil { - return mmCreateTZ.funcCreateTZ(ctx, userID, requestTxt) + return mmCreateTZ.funcCreateTZ(ctx, userID, requestTxt, fileData, fileName) } - mmCreateTZ.t.Fatalf("Unexpected call to RequestServiceMock.CreateTZ. %v %v %v", ctx, userID, requestTxt) + mmCreateTZ.t.Fatalf("Unexpected call to RequestServiceMock.CreateTZ. %v %v %v %v %v", ctx, userID, requestTxt, fileData, fileName) return } @@ -1225,12 +1287,14 @@ type RequestServiceMockGetMailingListByIDExpectation struct { type RequestServiceMockGetMailingListByIDParams struct { ctx context.Context requestID uuid.UUID + userID int } // RequestServiceMockGetMailingListByIDParamPtrs contains pointers to parameters of the RequestService.GetMailingListByID type RequestServiceMockGetMailingListByIDParamPtrs struct { ctx *context.Context requestID *uuid.UUID + userID *int } // RequestServiceMockGetMailingListByIDResults contains results of the RequestService.GetMailingListByID @@ -1244,6 +1308,7 @@ type RequestServiceMockGetMailingListByIDExpectationOrigins struct { origin string originCtx string originRequestID string + originUserID string } // Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning @@ -1257,7 +1322,7 @@ func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Optional() *m } // Expect sets up expected params for RequestService.GetMailingListByID -func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Expect(ctx context.Context, requestID uuid.UUID) *mRequestServiceMockGetMailingListByID { +func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Expect(ctx context.Context, requestID uuid.UUID, userID int) *mRequestServiceMockGetMailingListByID { if mmGetMailingListByID.mock.funcGetMailingListByID != nil { mmGetMailingListByID.mock.t.Fatalf("RequestServiceMock.GetMailingListByID mock is already set by Set") } @@ -1270,7 +1335,7 @@ func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Expect(ctx co mmGetMailingListByID.mock.t.Fatalf("RequestServiceMock.GetMailingListByID mock is already set by ExpectParams functions") } - mmGetMailingListByID.defaultExpectation.params = &RequestServiceMockGetMailingListByIDParams{ctx, requestID} + mmGetMailingListByID.defaultExpectation.params = &RequestServiceMockGetMailingListByIDParams{ctx, requestID, userID} mmGetMailingListByID.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) for _, e := range mmGetMailingListByID.expectations { if minimock.Equal(e.params, mmGetMailingListByID.defaultExpectation.params) { @@ -1327,8 +1392,31 @@ func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) ExpectRequest return mmGetMailingListByID } +// ExpectUserIDParam3 sets up expected param userID for RequestService.GetMailingListByID +func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) ExpectUserIDParam3(userID int) *mRequestServiceMockGetMailingListByID { + if mmGetMailingListByID.mock.funcGetMailingListByID != nil { + mmGetMailingListByID.mock.t.Fatalf("RequestServiceMock.GetMailingListByID mock is already set by Set") + } + + if mmGetMailingListByID.defaultExpectation == nil { + mmGetMailingListByID.defaultExpectation = &RequestServiceMockGetMailingListByIDExpectation{} + } + + if mmGetMailingListByID.defaultExpectation.params != nil { + mmGetMailingListByID.mock.t.Fatalf("RequestServiceMock.GetMailingListByID mock is already set by Expect") + } + + if mmGetMailingListByID.defaultExpectation.paramPtrs == nil { + mmGetMailingListByID.defaultExpectation.paramPtrs = &RequestServiceMockGetMailingListByIDParamPtrs{} + } + mmGetMailingListByID.defaultExpectation.paramPtrs.userID = &userID + mmGetMailingListByID.defaultExpectation.expectationOrigins.originUserID = minimock.CallerInfo(1) + + return mmGetMailingListByID +} + // Inspect accepts an inspector function that has same arguments as the RequestService.GetMailingListByID -func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Inspect(f func(ctx context.Context, requestID uuid.UUID)) *mRequestServiceMockGetMailingListByID { +func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Inspect(f func(ctx context.Context, requestID uuid.UUID, userID int)) *mRequestServiceMockGetMailingListByID { if mmGetMailingListByID.mock.inspectFuncGetMailingListByID != nil { mmGetMailingListByID.mock.t.Fatalf("Inspect function is already set for RequestServiceMock.GetMailingListByID") } @@ -1353,7 +1441,7 @@ func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Return(rp1 *m } // Set uses given function f to mock the RequestService.GetMailingListByID method -func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Set(f func(ctx context.Context, requestID uuid.UUID) (rp1 *model.RequestDetail, err error)) *RequestServiceMock { +func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Set(f func(ctx context.Context, requestID uuid.UUID, userID int) (rp1 *model.RequestDetail, err error)) *RequestServiceMock { if mmGetMailingListByID.defaultExpectation != nil { mmGetMailingListByID.mock.t.Fatalf("Default expectation is already set for the RequestService.GetMailingListByID method") } @@ -1369,14 +1457,14 @@ func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) Set(f func(ct // When sets expectation for the RequestService.GetMailingListByID which will trigger the result defined by the following // Then helper -func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) When(ctx context.Context, requestID uuid.UUID) *RequestServiceMockGetMailingListByIDExpectation { +func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) When(ctx context.Context, requestID uuid.UUID, userID int) *RequestServiceMockGetMailingListByIDExpectation { if mmGetMailingListByID.mock.funcGetMailingListByID != nil { mmGetMailingListByID.mock.t.Fatalf("RequestServiceMock.GetMailingListByID mock is already set by Set") } expectation := &RequestServiceMockGetMailingListByIDExpectation{ mock: mmGetMailingListByID.mock, - params: &RequestServiceMockGetMailingListByIDParams{ctx, requestID}, + params: &RequestServiceMockGetMailingListByIDParams{ctx, requestID, userID}, expectationOrigins: RequestServiceMockGetMailingListByIDExpectationOrigins{origin: minimock.CallerInfo(1)}, } mmGetMailingListByID.expectations = append(mmGetMailingListByID.expectations, expectation) @@ -1411,17 +1499,17 @@ func (mmGetMailingListByID *mRequestServiceMockGetMailingListByID) invocationsDo } // GetMailingListByID implements mm_service.RequestService -func (mmGetMailingListByID *RequestServiceMock) GetMailingListByID(ctx context.Context, requestID uuid.UUID) (rp1 *model.RequestDetail, err error) { +func (mmGetMailingListByID *RequestServiceMock) GetMailingListByID(ctx context.Context, requestID uuid.UUID, userID int) (rp1 *model.RequestDetail, err error) { mm_atomic.AddUint64(&mmGetMailingListByID.beforeGetMailingListByIDCounter, 1) defer mm_atomic.AddUint64(&mmGetMailingListByID.afterGetMailingListByIDCounter, 1) mmGetMailingListByID.t.Helper() if mmGetMailingListByID.inspectFuncGetMailingListByID != nil { - mmGetMailingListByID.inspectFuncGetMailingListByID(ctx, requestID) + mmGetMailingListByID.inspectFuncGetMailingListByID(ctx, requestID, userID) } - mm_params := RequestServiceMockGetMailingListByIDParams{ctx, requestID} + mm_params := RequestServiceMockGetMailingListByIDParams{ctx, requestID, userID} // Record call args mmGetMailingListByID.GetMailingListByIDMock.mutex.Lock() @@ -1440,7 +1528,7 @@ func (mmGetMailingListByID *RequestServiceMock) GetMailingListByID(ctx context.C mm_want := mmGetMailingListByID.GetMailingListByIDMock.defaultExpectation.params mm_want_ptrs := mmGetMailingListByID.GetMailingListByIDMock.defaultExpectation.paramPtrs - mm_got := RequestServiceMockGetMailingListByIDParams{ctx, requestID} + mm_got := RequestServiceMockGetMailingListByIDParams{ctx, requestID, userID} if mm_want_ptrs != nil { @@ -1454,6 +1542,11 @@ func (mmGetMailingListByID *RequestServiceMock) GetMailingListByID(ctx context.C mmGetMailingListByID.GetMailingListByIDMock.defaultExpectation.expectationOrigins.originRequestID, *mm_want_ptrs.requestID, mm_got.requestID, minimock.Diff(*mm_want_ptrs.requestID, mm_got.requestID)) } + if mm_want_ptrs.userID != nil && !minimock.Equal(*mm_want_ptrs.userID, mm_got.userID) { + mmGetMailingListByID.t.Errorf("RequestServiceMock.GetMailingListByID got unexpected parameter userID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetMailingListByID.GetMailingListByIDMock.defaultExpectation.expectationOrigins.originUserID, *mm_want_ptrs.userID, mm_got.userID, minimock.Diff(*mm_want_ptrs.userID, mm_got.userID)) + } + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { mmGetMailingListByID.t.Errorf("RequestServiceMock.GetMailingListByID got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", mmGetMailingListByID.GetMailingListByIDMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) @@ -1466,9 +1559,9 @@ func (mmGetMailingListByID *RequestServiceMock) GetMailingListByID(ctx context.C return (*mm_results).rp1, (*mm_results).err } if mmGetMailingListByID.funcGetMailingListByID != nil { - return mmGetMailingListByID.funcGetMailingListByID(ctx, requestID) + return mmGetMailingListByID.funcGetMailingListByID(ctx, requestID, userID) } - mmGetMailingListByID.t.Fatalf("Unexpected call to RequestServiceMock.GetMailingListByID. %v %v", ctx, requestID) + mmGetMailingListByID.t.Fatalf("Unexpected call to RequestServiceMock.GetMailingListByID. %v %v %v", ctx, requestID, userID) return } diff --git a/internal/mocks/supplier_service_mock.go b/internal/mocks/supplier_service_mock.go index 3b165f1..74af0ae 100644 --- a/internal/mocks/supplier_service_mock.go +++ b/internal/mocks/supplier_service_mock.go @@ -19,9 +19,9 @@ type SupplierServiceMock struct { t minimock.Tester finishOnce sync.Once - funcExportExcel func(ctx context.Context, requestID uuid.UUID) (ba1 []byte, err error) + funcExportExcel func(ctx context.Context, requestID uuid.UUID, userID int) (ba1 []byte, err error) funcExportExcelOrigin string - inspectFuncExportExcel func(ctx context.Context, requestID uuid.UUID) + inspectFuncExportExcel func(ctx context.Context, requestID uuid.UUID, userID int) afterExportExcelCounter uint64 beforeExportExcelCounter uint64 ExportExcelMock mSupplierServiceMockExportExcel @@ -71,12 +71,14 @@ type SupplierServiceMockExportExcelExpectation struct { type SupplierServiceMockExportExcelParams struct { ctx context.Context requestID uuid.UUID + userID int } // SupplierServiceMockExportExcelParamPtrs contains pointers to parameters of the SupplierService.ExportExcel type SupplierServiceMockExportExcelParamPtrs struct { ctx *context.Context requestID *uuid.UUID + userID *int } // SupplierServiceMockExportExcelResults contains results of the SupplierService.ExportExcel @@ -90,6 +92,7 @@ type SupplierServiceMockExportExcelExpectationOrigins struct { origin string originCtx string originRequestID string + originUserID string } // Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning @@ -103,7 +106,7 @@ func (mmExportExcel *mSupplierServiceMockExportExcel) Optional() *mSupplierServi } // Expect sets up expected params for SupplierService.ExportExcel -func (mmExportExcel *mSupplierServiceMockExportExcel) Expect(ctx context.Context, requestID uuid.UUID) *mSupplierServiceMockExportExcel { +func (mmExportExcel *mSupplierServiceMockExportExcel) Expect(ctx context.Context, requestID uuid.UUID, userID int) *mSupplierServiceMockExportExcel { if mmExportExcel.mock.funcExportExcel != nil { mmExportExcel.mock.t.Fatalf("SupplierServiceMock.ExportExcel mock is already set by Set") } @@ -116,7 +119,7 @@ func (mmExportExcel *mSupplierServiceMockExportExcel) Expect(ctx context.Context mmExportExcel.mock.t.Fatalf("SupplierServiceMock.ExportExcel mock is already set by ExpectParams functions") } - mmExportExcel.defaultExpectation.params = &SupplierServiceMockExportExcelParams{ctx, requestID} + mmExportExcel.defaultExpectation.params = &SupplierServiceMockExportExcelParams{ctx, requestID, userID} mmExportExcel.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) for _, e := range mmExportExcel.expectations { if minimock.Equal(e.params, mmExportExcel.defaultExpectation.params) { @@ -173,8 +176,31 @@ func (mmExportExcel *mSupplierServiceMockExportExcel) ExpectRequestIDParam2(requ return mmExportExcel } +// ExpectUserIDParam3 sets up expected param userID for SupplierService.ExportExcel +func (mmExportExcel *mSupplierServiceMockExportExcel) ExpectUserIDParam3(userID int) *mSupplierServiceMockExportExcel { + if mmExportExcel.mock.funcExportExcel != nil { + mmExportExcel.mock.t.Fatalf("SupplierServiceMock.ExportExcel mock is already set by Set") + } + + if mmExportExcel.defaultExpectation == nil { + mmExportExcel.defaultExpectation = &SupplierServiceMockExportExcelExpectation{} + } + + if mmExportExcel.defaultExpectation.params != nil { + mmExportExcel.mock.t.Fatalf("SupplierServiceMock.ExportExcel mock is already set by Expect") + } + + if mmExportExcel.defaultExpectation.paramPtrs == nil { + mmExportExcel.defaultExpectation.paramPtrs = &SupplierServiceMockExportExcelParamPtrs{} + } + mmExportExcel.defaultExpectation.paramPtrs.userID = &userID + mmExportExcel.defaultExpectation.expectationOrigins.originUserID = minimock.CallerInfo(1) + + return mmExportExcel +} + // Inspect accepts an inspector function that has same arguments as the SupplierService.ExportExcel -func (mmExportExcel *mSupplierServiceMockExportExcel) Inspect(f func(ctx context.Context, requestID uuid.UUID)) *mSupplierServiceMockExportExcel { +func (mmExportExcel *mSupplierServiceMockExportExcel) Inspect(f func(ctx context.Context, requestID uuid.UUID, userID int)) *mSupplierServiceMockExportExcel { if mmExportExcel.mock.inspectFuncExportExcel != nil { mmExportExcel.mock.t.Fatalf("Inspect function is already set for SupplierServiceMock.ExportExcel") } @@ -199,7 +225,7 @@ func (mmExportExcel *mSupplierServiceMockExportExcel) Return(ba1 []byte, err err } // Set uses given function f to mock the SupplierService.ExportExcel method -func (mmExportExcel *mSupplierServiceMockExportExcel) Set(f func(ctx context.Context, requestID uuid.UUID) (ba1 []byte, err error)) *SupplierServiceMock { +func (mmExportExcel *mSupplierServiceMockExportExcel) Set(f func(ctx context.Context, requestID uuid.UUID, userID int) (ba1 []byte, err error)) *SupplierServiceMock { if mmExportExcel.defaultExpectation != nil { mmExportExcel.mock.t.Fatalf("Default expectation is already set for the SupplierService.ExportExcel method") } @@ -215,14 +241,14 @@ func (mmExportExcel *mSupplierServiceMockExportExcel) Set(f func(ctx context.Con // When sets expectation for the SupplierService.ExportExcel which will trigger the result defined by the following // Then helper -func (mmExportExcel *mSupplierServiceMockExportExcel) When(ctx context.Context, requestID uuid.UUID) *SupplierServiceMockExportExcelExpectation { +func (mmExportExcel *mSupplierServiceMockExportExcel) When(ctx context.Context, requestID uuid.UUID, userID int) *SupplierServiceMockExportExcelExpectation { if mmExportExcel.mock.funcExportExcel != nil { mmExportExcel.mock.t.Fatalf("SupplierServiceMock.ExportExcel mock is already set by Set") } expectation := &SupplierServiceMockExportExcelExpectation{ mock: mmExportExcel.mock, - params: &SupplierServiceMockExportExcelParams{ctx, requestID}, + params: &SupplierServiceMockExportExcelParams{ctx, requestID, userID}, expectationOrigins: SupplierServiceMockExportExcelExpectationOrigins{origin: minimock.CallerInfo(1)}, } mmExportExcel.expectations = append(mmExportExcel.expectations, expectation) @@ -257,17 +283,17 @@ func (mmExportExcel *mSupplierServiceMockExportExcel) invocationsDone() bool { } // ExportExcel implements mm_service.SupplierService -func (mmExportExcel *SupplierServiceMock) ExportExcel(ctx context.Context, requestID uuid.UUID) (ba1 []byte, err error) { +func (mmExportExcel *SupplierServiceMock) ExportExcel(ctx context.Context, requestID uuid.UUID, userID int) (ba1 []byte, err error) { mm_atomic.AddUint64(&mmExportExcel.beforeExportExcelCounter, 1) defer mm_atomic.AddUint64(&mmExportExcel.afterExportExcelCounter, 1) mmExportExcel.t.Helper() if mmExportExcel.inspectFuncExportExcel != nil { - mmExportExcel.inspectFuncExportExcel(ctx, requestID) + mmExportExcel.inspectFuncExportExcel(ctx, requestID, userID) } - mm_params := SupplierServiceMockExportExcelParams{ctx, requestID} + mm_params := SupplierServiceMockExportExcelParams{ctx, requestID, userID} // Record call args mmExportExcel.ExportExcelMock.mutex.Lock() @@ -286,7 +312,7 @@ func (mmExportExcel *SupplierServiceMock) ExportExcel(ctx context.Context, reque mm_want := mmExportExcel.ExportExcelMock.defaultExpectation.params mm_want_ptrs := mmExportExcel.ExportExcelMock.defaultExpectation.paramPtrs - mm_got := SupplierServiceMockExportExcelParams{ctx, requestID} + mm_got := SupplierServiceMockExportExcelParams{ctx, requestID, userID} if mm_want_ptrs != nil { @@ -300,6 +326,11 @@ func (mmExportExcel *SupplierServiceMock) ExportExcel(ctx context.Context, reque mmExportExcel.ExportExcelMock.defaultExpectation.expectationOrigins.originRequestID, *mm_want_ptrs.requestID, mm_got.requestID, minimock.Diff(*mm_want_ptrs.requestID, mm_got.requestID)) } + if mm_want_ptrs.userID != nil && !minimock.Equal(*mm_want_ptrs.userID, mm_got.userID) { + mmExportExcel.t.Errorf("SupplierServiceMock.ExportExcel got unexpected parameter userID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmExportExcel.ExportExcelMock.defaultExpectation.expectationOrigins.originUserID, *mm_want_ptrs.userID, mm_got.userID, minimock.Diff(*mm_want_ptrs.userID, mm_got.userID)) + } + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { mmExportExcel.t.Errorf("SupplierServiceMock.ExportExcel got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", mmExportExcel.ExportExcelMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) @@ -312,9 +343,9 @@ func (mmExportExcel *SupplierServiceMock) ExportExcel(ctx context.Context, reque return (*mm_results).ba1, (*mm_results).err } if mmExportExcel.funcExportExcel != nil { - return mmExportExcel.funcExportExcel(ctx, requestID) + return mmExportExcel.funcExportExcel(ctx, requestID, userID) } - mmExportExcel.t.Fatalf("Unexpected call to SupplierServiceMock.ExportExcel. %v %v", ctx, requestID) + mmExportExcel.t.Fatalf("Unexpected call to SupplierServiceMock.ExportExcel. %v %v %v", ctx, requestID, userID) return } diff --git a/internal/mocks/token_usage_repository_mock.go b/internal/mocks/token_usage_repository_mock.go index 779b8f3..88f1d84 100644 --- a/internal/mocks/token_usage_repository_mock.go +++ b/internal/mocks/token_usage_repository_mock.go @@ -33,6 +33,13 @@ type TokenUsageRepositoryMock struct { afterCreateTxCounter uint64 beforeCreateTxCounter uint64 CreateTxMock mTokenUsageRepositoryMockCreateTx + + funcGetBalanceStatistics func(ctx context.Context, userID int) (averageCost float64, history []*model.WriteOffHistory, err error) + funcGetBalanceStatisticsOrigin string + inspectFuncGetBalanceStatistics func(ctx context.Context, userID int) + afterGetBalanceStatisticsCounter uint64 + beforeGetBalanceStatisticsCounter uint64 + GetBalanceStatisticsMock mTokenUsageRepositoryMockGetBalanceStatistics } // NewTokenUsageRepositoryMock returns a mock for mm_repository.TokenUsageRepository @@ -49,6 +56,9 @@ func NewTokenUsageRepositoryMock(t minimock.Tester) *TokenUsageRepositoryMock { m.CreateTxMock = mTokenUsageRepositoryMockCreateTx{mock: m} m.CreateTxMock.callArgs = []*TokenUsageRepositoryMockCreateTxParams{} + m.GetBalanceStatisticsMock = mTokenUsageRepositoryMockGetBalanceStatistics{mock: m} + m.GetBalanceStatisticsMock.callArgs = []*TokenUsageRepositoryMockGetBalanceStatisticsParams{} + t.Cleanup(m.MinimockFinish) return m @@ -769,6 +779,350 @@ func (m *TokenUsageRepositoryMock) MinimockCreateTxInspect() { } } +type mTokenUsageRepositoryMockGetBalanceStatistics struct { + optional bool + mock *TokenUsageRepositoryMock + defaultExpectation *TokenUsageRepositoryMockGetBalanceStatisticsExpectation + expectations []*TokenUsageRepositoryMockGetBalanceStatisticsExpectation + + callArgs []*TokenUsageRepositoryMockGetBalanceStatisticsParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// TokenUsageRepositoryMockGetBalanceStatisticsExpectation specifies expectation struct of the TokenUsageRepository.GetBalanceStatistics +type TokenUsageRepositoryMockGetBalanceStatisticsExpectation struct { + mock *TokenUsageRepositoryMock + params *TokenUsageRepositoryMockGetBalanceStatisticsParams + paramPtrs *TokenUsageRepositoryMockGetBalanceStatisticsParamPtrs + expectationOrigins TokenUsageRepositoryMockGetBalanceStatisticsExpectationOrigins + results *TokenUsageRepositoryMockGetBalanceStatisticsResults + returnOrigin string + Counter uint64 +} + +// TokenUsageRepositoryMockGetBalanceStatisticsParams contains parameters of the TokenUsageRepository.GetBalanceStatistics +type TokenUsageRepositoryMockGetBalanceStatisticsParams struct { + ctx context.Context + userID int +} + +// TokenUsageRepositoryMockGetBalanceStatisticsParamPtrs contains pointers to parameters of the TokenUsageRepository.GetBalanceStatistics +type TokenUsageRepositoryMockGetBalanceStatisticsParamPtrs struct { + ctx *context.Context + userID *int +} + +// TokenUsageRepositoryMockGetBalanceStatisticsResults contains results of the TokenUsageRepository.GetBalanceStatistics +type TokenUsageRepositoryMockGetBalanceStatisticsResults struct { + averageCost float64 + history []*model.WriteOffHistory + err error +} + +// TokenUsageRepositoryMockGetBalanceStatisticsOrigins contains origins of expectations of the TokenUsageRepository.GetBalanceStatistics +type TokenUsageRepositoryMockGetBalanceStatisticsExpectationOrigins struct { + origin string + originCtx string + originUserID string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) Optional() *mTokenUsageRepositoryMockGetBalanceStatistics { + mmGetBalanceStatistics.optional = true + return mmGetBalanceStatistics +} + +// Expect sets up expected params for TokenUsageRepository.GetBalanceStatistics +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) Expect(ctx context.Context, userID int) *mTokenUsageRepositoryMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &TokenUsageRepositoryMockGetBalanceStatisticsExpectation{} + } + + if mmGetBalanceStatistics.defaultExpectation.paramPtrs != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by ExpectParams functions") + } + + mmGetBalanceStatistics.defaultExpectation.params = &TokenUsageRepositoryMockGetBalanceStatisticsParams{ctx, userID} + mmGetBalanceStatistics.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmGetBalanceStatistics.expectations { + if minimock.Equal(e.params, mmGetBalanceStatistics.defaultExpectation.params) { + mmGetBalanceStatistics.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmGetBalanceStatistics.defaultExpectation.params) + } + } + + return mmGetBalanceStatistics +} + +// ExpectCtxParam1 sets up expected param ctx for TokenUsageRepository.GetBalanceStatistics +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) ExpectCtxParam1(ctx context.Context) *mTokenUsageRepositoryMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &TokenUsageRepositoryMockGetBalanceStatisticsExpectation{} + } + + if mmGetBalanceStatistics.defaultExpectation.params != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by Expect") + } + + if mmGetBalanceStatistics.defaultExpectation.paramPtrs == nil { + mmGetBalanceStatistics.defaultExpectation.paramPtrs = &TokenUsageRepositoryMockGetBalanceStatisticsParamPtrs{} + } + mmGetBalanceStatistics.defaultExpectation.paramPtrs.ctx = &ctx + mmGetBalanceStatistics.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmGetBalanceStatistics +} + +// ExpectUserIDParam2 sets up expected param userID for TokenUsageRepository.GetBalanceStatistics +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) ExpectUserIDParam2(userID int) *mTokenUsageRepositoryMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &TokenUsageRepositoryMockGetBalanceStatisticsExpectation{} + } + + if mmGetBalanceStatistics.defaultExpectation.params != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by Expect") + } + + if mmGetBalanceStatistics.defaultExpectation.paramPtrs == nil { + mmGetBalanceStatistics.defaultExpectation.paramPtrs = &TokenUsageRepositoryMockGetBalanceStatisticsParamPtrs{} + } + mmGetBalanceStatistics.defaultExpectation.paramPtrs.userID = &userID + mmGetBalanceStatistics.defaultExpectation.expectationOrigins.originUserID = minimock.CallerInfo(1) + + return mmGetBalanceStatistics +} + +// Inspect accepts an inspector function that has same arguments as the TokenUsageRepository.GetBalanceStatistics +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) Inspect(f func(ctx context.Context, userID int)) *mTokenUsageRepositoryMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.inspectFuncGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("Inspect function is already set for TokenUsageRepositoryMock.GetBalanceStatistics") + } + + mmGetBalanceStatistics.mock.inspectFuncGetBalanceStatistics = f + + return mmGetBalanceStatistics +} + +// Return sets up results that will be returned by TokenUsageRepository.GetBalanceStatistics +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) Return(averageCost float64, history []*model.WriteOffHistory, err error) *TokenUsageRepositoryMock { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &TokenUsageRepositoryMockGetBalanceStatisticsExpectation{mock: mmGetBalanceStatistics.mock} + } + mmGetBalanceStatistics.defaultExpectation.results = &TokenUsageRepositoryMockGetBalanceStatisticsResults{averageCost, history, err} + mmGetBalanceStatistics.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmGetBalanceStatistics.mock +} + +// Set uses given function f to mock the TokenUsageRepository.GetBalanceStatistics method +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) Set(f func(ctx context.Context, userID int) (averageCost float64, history []*model.WriteOffHistory, err error)) *TokenUsageRepositoryMock { + if mmGetBalanceStatistics.defaultExpectation != nil { + mmGetBalanceStatistics.mock.t.Fatalf("Default expectation is already set for the TokenUsageRepository.GetBalanceStatistics method") + } + + if len(mmGetBalanceStatistics.expectations) > 0 { + mmGetBalanceStatistics.mock.t.Fatalf("Some expectations are already set for the TokenUsageRepository.GetBalanceStatistics method") + } + + mmGetBalanceStatistics.mock.funcGetBalanceStatistics = f + mmGetBalanceStatistics.mock.funcGetBalanceStatisticsOrigin = minimock.CallerInfo(1) + return mmGetBalanceStatistics.mock +} + +// When sets expectation for the TokenUsageRepository.GetBalanceStatistics which will trigger the result defined by the following +// Then helper +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) When(ctx context.Context, userID int) *TokenUsageRepositoryMockGetBalanceStatisticsExpectation { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("TokenUsageRepositoryMock.GetBalanceStatistics mock is already set by Set") + } + + expectation := &TokenUsageRepositoryMockGetBalanceStatisticsExpectation{ + mock: mmGetBalanceStatistics.mock, + params: &TokenUsageRepositoryMockGetBalanceStatisticsParams{ctx, userID}, + expectationOrigins: TokenUsageRepositoryMockGetBalanceStatisticsExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmGetBalanceStatistics.expectations = append(mmGetBalanceStatistics.expectations, expectation) + return expectation +} + +// Then sets up TokenUsageRepository.GetBalanceStatistics return parameters for the expectation previously defined by the When method +func (e *TokenUsageRepositoryMockGetBalanceStatisticsExpectation) Then(averageCost float64, history []*model.WriteOffHistory, err error) *TokenUsageRepositoryMock { + e.results = &TokenUsageRepositoryMockGetBalanceStatisticsResults{averageCost, history, err} + return e.mock +} + +// Times sets number of times TokenUsageRepository.GetBalanceStatistics should be invoked +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) Times(n uint64) *mTokenUsageRepositoryMockGetBalanceStatistics { + if n == 0 { + mmGetBalanceStatistics.mock.t.Fatalf("Times of TokenUsageRepositoryMock.GetBalanceStatistics mock can not be zero") + } + mm_atomic.StoreUint64(&mmGetBalanceStatistics.expectedInvocations, n) + mmGetBalanceStatistics.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmGetBalanceStatistics +} + +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) invocationsDone() bool { + if len(mmGetBalanceStatistics.expectations) == 0 && mmGetBalanceStatistics.defaultExpectation == nil && mmGetBalanceStatistics.mock.funcGetBalanceStatistics == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmGetBalanceStatistics.mock.afterGetBalanceStatisticsCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmGetBalanceStatistics.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// GetBalanceStatistics implements mm_repository.TokenUsageRepository +func (mmGetBalanceStatistics *TokenUsageRepositoryMock) GetBalanceStatistics(ctx context.Context, userID int) (averageCost float64, history []*model.WriteOffHistory, err error) { + mm_atomic.AddUint64(&mmGetBalanceStatistics.beforeGetBalanceStatisticsCounter, 1) + defer mm_atomic.AddUint64(&mmGetBalanceStatistics.afterGetBalanceStatisticsCounter, 1) + + mmGetBalanceStatistics.t.Helper() + + if mmGetBalanceStatistics.inspectFuncGetBalanceStatistics != nil { + mmGetBalanceStatistics.inspectFuncGetBalanceStatistics(ctx, userID) + } + + mm_params := TokenUsageRepositoryMockGetBalanceStatisticsParams{ctx, userID} + + // Record call args + mmGetBalanceStatistics.GetBalanceStatisticsMock.mutex.Lock() + mmGetBalanceStatistics.GetBalanceStatisticsMock.callArgs = append(mmGetBalanceStatistics.GetBalanceStatisticsMock.callArgs, &mm_params) + mmGetBalanceStatistics.GetBalanceStatisticsMock.mutex.Unlock() + + for _, e := range mmGetBalanceStatistics.GetBalanceStatisticsMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.averageCost, e.results.history, e.results.err + } + } + + if mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.Counter, 1) + mm_want := mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.params + mm_want_ptrs := mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.paramPtrs + + mm_got := TokenUsageRepositoryMockGetBalanceStatisticsParams{ctx, userID} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmGetBalanceStatistics.t.Errorf("TokenUsageRepositoryMock.GetBalanceStatistics got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.userID != nil && !minimock.Equal(*mm_want_ptrs.userID, mm_got.userID) { + mmGetBalanceStatistics.t.Errorf("TokenUsageRepositoryMock.GetBalanceStatistics got unexpected parameter userID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.originUserID, *mm_want_ptrs.userID, mm_got.userID, minimock.Diff(*mm_want_ptrs.userID, mm_got.userID)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmGetBalanceStatistics.t.Errorf("TokenUsageRepositoryMock.GetBalanceStatistics got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.results + if mm_results == nil { + mmGetBalanceStatistics.t.Fatal("No results are set for the TokenUsageRepositoryMock.GetBalanceStatistics") + } + return (*mm_results).averageCost, (*mm_results).history, (*mm_results).err + } + if mmGetBalanceStatistics.funcGetBalanceStatistics != nil { + return mmGetBalanceStatistics.funcGetBalanceStatistics(ctx, userID) + } + mmGetBalanceStatistics.t.Fatalf("Unexpected call to TokenUsageRepositoryMock.GetBalanceStatistics. %v %v", ctx, userID) + return +} + +// GetBalanceStatisticsAfterCounter returns a count of finished TokenUsageRepositoryMock.GetBalanceStatistics invocations +func (mmGetBalanceStatistics *TokenUsageRepositoryMock) GetBalanceStatisticsAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetBalanceStatistics.afterGetBalanceStatisticsCounter) +} + +// GetBalanceStatisticsBeforeCounter returns a count of TokenUsageRepositoryMock.GetBalanceStatistics invocations +func (mmGetBalanceStatistics *TokenUsageRepositoryMock) GetBalanceStatisticsBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetBalanceStatistics.beforeGetBalanceStatisticsCounter) +} + +// Calls returns a list of arguments used in each call to TokenUsageRepositoryMock.GetBalanceStatistics. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmGetBalanceStatistics *mTokenUsageRepositoryMockGetBalanceStatistics) Calls() []*TokenUsageRepositoryMockGetBalanceStatisticsParams { + mmGetBalanceStatistics.mutex.RLock() + + argCopy := make([]*TokenUsageRepositoryMockGetBalanceStatisticsParams, len(mmGetBalanceStatistics.callArgs)) + copy(argCopy, mmGetBalanceStatistics.callArgs) + + mmGetBalanceStatistics.mutex.RUnlock() + + return argCopy +} + +// MinimockGetBalanceStatisticsDone returns true if the count of the GetBalanceStatistics invocations corresponds +// the number of defined expectations +func (m *TokenUsageRepositoryMock) MinimockGetBalanceStatisticsDone() bool { + if m.GetBalanceStatisticsMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.GetBalanceStatisticsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.GetBalanceStatisticsMock.invocationsDone() +} + +// MinimockGetBalanceStatisticsInspect logs each unmet expectation +func (m *TokenUsageRepositoryMock) MinimockGetBalanceStatisticsInspect() { + for _, e := range m.GetBalanceStatisticsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to TokenUsageRepositoryMock.GetBalanceStatistics at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterGetBalanceStatisticsCounter := mm_atomic.LoadUint64(&m.afterGetBalanceStatisticsCounter) + // if default expectation was set then invocations count should be greater than zero + if m.GetBalanceStatisticsMock.defaultExpectation != nil && afterGetBalanceStatisticsCounter < 1 { + if m.GetBalanceStatisticsMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to TokenUsageRepositoryMock.GetBalanceStatistics at\n%s", m.GetBalanceStatisticsMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to TokenUsageRepositoryMock.GetBalanceStatistics at\n%s with params: %#v", m.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.origin, *m.GetBalanceStatisticsMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcGetBalanceStatistics != nil && afterGetBalanceStatisticsCounter < 1 { + m.t.Errorf("Expected call to TokenUsageRepositoryMock.GetBalanceStatistics at\n%s", m.funcGetBalanceStatisticsOrigin) + } + + if !m.GetBalanceStatisticsMock.invocationsDone() && afterGetBalanceStatisticsCounter > 0 { + m.t.Errorf("Expected %d calls to TokenUsageRepositoryMock.GetBalanceStatistics at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.GetBalanceStatisticsMock.expectedInvocations), m.GetBalanceStatisticsMock.expectedInvocationsOrigin, afterGetBalanceStatisticsCounter) + } +} + // MinimockFinish checks that all mocked methods have been called the expected number of times func (m *TokenUsageRepositoryMock) MinimockFinish() { m.finishOnce.Do(func() { @@ -776,6 +1130,8 @@ func (m *TokenUsageRepositoryMock) MinimockFinish() { m.MinimockCreateInspect() m.MinimockCreateTxInspect() + + m.MinimockGetBalanceStatisticsInspect() } }) } @@ -800,5 +1156,6 @@ func (m *TokenUsageRepositoryMock) minimockDone() bool { done := true return done && m.MinimockCreateDone() && - m.MinimockCreateTxDone() + m.MinimockCreateTxDone() && + m.MinimockGetBalanceStatisticsDone() } diff --git a/internal/mocks/user_service_mock.go b/internal/mocks/user_service_mock.go index 1b1a868..5358af4 100644 --- a/internal/mocks/user_service_mock.go +++ b/internal/mocks/user_service_mock.go @@ -26,6 +26,13 @@ type UserServiceMock struct { beforeGetBalanceCounter uint64 GetBalanceMock mUserServiceMockGetBalance + funcGetBalanceStatistics func(ctx context.Context, userID int) (bp1 *mm_service.BalanceStatistics, err error) + funcGetBalanceStatisticsOrigin string + inspectFuncGetBalanceStatistics func(ctx context.Context, userID int) + afterGetBalanceStatisticsCounter uint64 + beforeGetBalanceStatisticsCounter uint64 + GetBalanceStatisticsMock mUserServiceMockGetBalanceStatistics + funcGetInfo func(ctx context.Context, userID int) (up1 *mm_service.UserInfo, err error) funcGetInfoOrigin string inspectFuncGetInfo func(ctx context.Context, userID int) @@ -52,6 +59,9 @@ func NewUserServiceMock(t minimock.Tester) *UserServiceMock { m.GetBalanceMock = mUserServiceMockGetBalance{mock: m} m.GetBalanceMock.callArgs = []*UserServiceMockGetBalanceParams{} + m.GetBalanceStatisticsMock = mUserServiceMockGetBalanceStatistics{mock: m} + m.GetBalanceStatisticsMock.callArgs = []*UserServiceMockGetBalanceStatisticsParams{} + m.GetInfoMock = mUserServiceMockGetInfo{mock: m} m.GetInfoMock.callArgs = []*UserServiceMockGetInfoParams{} @@ -406,6 +416,349 @@ func (m *UserServiceMock) MinimockGetBalanceInspect() { } } +type mUserServiceMockGetBalanceStatistics struct { + optional bool + mock *UserServiceMock + defaultExpectation *UserServiceMockGetBalanceStatisticsExpectation + expectations []*UserServiceMockGetBalanceStatisticsExpectation + + callArgs []*UserServiceMockGetBalanceStatisticsParams + mutex sync.RWMutex + + expectedInvocations uint64 + expectedInvocationsOrigin string +} + +// UserServiceMockGetBalanceStatisticsExpectation specifies expectation struct of the UserService.GetBalanceStatistics +type UserServiceMockGetBalanceStatisticsExpectation struct { + mock *UserServiceMock + params *UserServiceMockGetBalanceStatisticsParams + paramPtrs *UserServiceMockGetBalanceStatisticsParamPtrs + expectationOrigins UserServiceMockGetBalanceStatisticsExpectationOrigins + results *UserServiceMockGetBalanceStatisticsResults + returnOrigin string + Counter uint64 +} + +// UserServiceMockGetBalanceStatisticsParams contains parameters of the UserService.GetBalanceStatistics +type UserServiceMockGetBalanceStatisticsParams struct { + ctx context.Context + userID int +} + +// UserServiceMockGetBalanceStatisticsParamPtrs contains pointers to parameters of the UserService.GetBalanceStatistics +type UserServiceMockGetBalanceStatisticsParamPtrs struct { + ctx *context.Context + userID *int +} + +// UserServiceMockGetBalanceStatisticsResults contains results of the UserService.GetBalanceStatistics +type UserServiceMockGetBalanceStatisticsResults struct { + bp1 *mm_service.BalanceStatistics + err error +} + +// UserServiceMockGetBalanceStatisticsOrigins contains origins of expectations of the UserService.GetBalanceStatistics +type UserServiceMockGetBalanceStatisticsExpectationOrigins struct { + origin string + originCtx string + originUserID string +} + +// Marks this method to be optional. The default behavior of any method with Return() is '1 or more', meaning +// the test will fail minimock's automatic final call check if the mocked method was not called at least once. +// Optional() makes method check to work in '0 or more' mode. +// It is NOT RECOMMENDED to use this option unless you really need it, as default behaviour helps to +// catch the problems when the expected method call is totally skipped during test run. +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) Optional() *mUserServiceMockGetBalanceStatistics { + mmGetBalanceStatistics.optional = true + return mmGetBalanceStatistics +} + +// Expect sets up expected params for UserService.GetBalanceStatistics +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) Expect(ctx context.Context, userID int) *mUserServiceMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &UserServiceMockGetBalanceStatisticsExpectation{} + } + + if mmGetBalanceStatistics.defaultExpectation.paramPtrs != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by ExpectParams functions") + } + + mmGetBalanceStatistics.defaultExpectation.params = &UserServiceMockGetBalanceStatisticsParams{ctx, userID} + mmGetBalanceStatistics.defaultExpectation.expectationOrigins.origin = minimock.CallerInfo(1) + for _, e := range mmGetBalanceStatistics.expectations { + if minimock.Equal(e.params, mmGetBalanceStatistics.defaultExpectation.params) { + mmGetBalanceStatistics.mock.t.Fatalf("Expectation set by When has same params: %#v", *mmGetBalanceStatistics.defaultExpectation.params) + } + } + + return mmGetBalanceStatistics +} + +// ExpectCtxParam1 sets up expected param ctx for UserService.GetBalanceStatistics +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) ExpectCtxParam1(ctx context.Context) *mUserServiceMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &UserServiceMockGetBalanceStatisticsExpectation{} + } + + if mmGetBalanceStatistics.defaultExpectation.params != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by Expect") + } + + if mmGetBalanceStatistics.defaultExpectation.paramPtrs == nil { + mmGetBalanceStatistics.defaultExpectation.paramPtrs = &UserServiceMockGetBalanceStatisticsParamPtrs{} + } + mmGetBalanceStatistics.defaultExpectation.paramPtrs.ctx = &ctx + mmGetBalanceStatistics.defaultExpectation.expectationOrigins.originCtx = minimock.CallerInfo(1) + + return mmGetBalanceStatistics +} + +// ExpectUserIDParam2 sets up expected param userID for UserService.GetBalanceStatistics +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) ExpectUserIDParam2(userID int) *mUserServiceMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &UserServiceMockGetBalanceStatisticsExpectation{} + } + + if mmGetBalanceStatistics.defaultExpectation.params != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by Expect") + } + + if mmGetBalanceStatistics.defaultExpectation.paramPtrs == nil { + mmGetBalanceStatistics.defaultExpectation.paramPtrs = &UserServiceMockGetBalanceStatisticsParamPtrs{} + } + mmGetBalanceStatistics.defaultExpectation.paramPtrs.userID = &userID + mmGetBalanceStatistics.defaultExpectation.expectationOrigins.originUserID = minimock.CallerInfo(1) + + return mmGetBalanceStatistics +} + +// Inspect accepts an inspector function that has same arguments as the UserService.GetBalanceStatistics +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) Inspect(f func(ctx context.Context, userID int)) *mUserServiceMockGetBalanceStatistics { + if mmGetBalanceStatistics.mock.inspectFuncGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("Inspect function is already set for UserServiceMock.GetBalanceStatistics") + } + + mmGetBalanceStatistics.mock.inspectFuncGetBalanceStatistics = f + + return mmGetBalanceStatistics +} + +// Return sets up results that will be returned by UserService.GetBalanceStatistics +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) Return(bp1 *mm_service.BalanceStatistics, err error) *UserServiceMock { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by Set") + } + + if mmGetBalanceStatistics.defaultExpectation == nil { + mmGetBalanceStatistics.defaultExpectation = &UserServiceMockGetBalanceStatisticsExpectation{mock: mmGetBalanceStatistics.mock} + } + mmGetBalanceStatistics.defaultExpectation.results = &UserServiceMockGetBalanceStatisticsResults{bp1, err} + mmGetBalanceStatistics.defaultExpectation.returnOrigin = minimock.CallerInfo(1) + return mmGetBalanceStatistics.mock +} + +// Set uses given function f to mock the UserService.GetBalanceStatistics method +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) Set(f func(ctx context.Context, userID int) (bp1 *mm_service.BalanceStatistics, err error)) *UserServiceMock { + if mmGetBalanceStatistics.defaultExpectation != nil { + mmGetBalanceStatistics.mock.t.Fatalf("Default expectation is already set for the UserService.GetBalanceStatistics method") + } + + if len(mmGetBalanceStatistics.expectations) > 0 { + mmGetBalanceStatistics.mock.t.Fatalf("Some expectations are already set for the UserService.GetBalanceStatistics method") + } + + mmGetBalanceStatistics.mock.funcGetBalanceStatistics = f + mmGetBalanceStatistics.mock.funcGetBalanceStatisticsOrigin = minimock.CallerInfo(1) + return mmGetBalanceStatistics.mock +} + +// When sets expectation for the UserService.GetBalanceStatistics which will trigger the result defined by the following +// Then helper +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) When(ctx context.Context, userID int) *UserServiceMockGetBalanceStatisticsExpectation { + if mmGetBalanceStatistics.mock.funcGetBalanceStatistics != nil { + mmGetBalanceStatistics.mock.t.Fatalf("UserServiceMock.GetBalanceStatistics mock is already set by Set") + } + + expectation := &UserServiceMockGetBalanceStatisticsExpectation{ + mock: mmGetBalanceStatistics.mock, + params: &UserServiceMockGetBalanceStatisticsParams{ctx, userID}, + expectationOrigins: UserServiceMockGetBalanceStatisticsExpectationOrigins{origin: minimock.CallerInfo(1)}, + } + mmGetBalanceStatistics.expectations = append(mmGetBalanceStatistics.expectations, expectation) + return expectation +} + +// Then sets up UserService.GetBalanceStatistics return parameters for the expectation previously defined by the When method +func (e *UserServiceMockGetBalanceStatisticsExpectation) Then(bp1 *mm_service.BalanceStatistics, err error) *UserServiceMock { + e.results = &UserServiceMockGetBalanceStatisticsResults{bp1, err} + return e.mock +} + +// Times sets number of times UserService.GetBalanceStatistics should be invoked +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) Times(n uint64) *mUserServiceMockGetBalanceStatistics { + if n == 0 { + mmGetBalanceStatistics.mock.t.Fatalf("Times of UserServiceMock.GetBalanceStatistics mock can not be zero") + } + mm_atomic.StoreUint64(&mmGetBalanceStatistics.expectedInvocations, n) + mmGetBalanceStatistics.expectedInvocationsOrigin = minimock.CallerInfo(1) + return mmGetBalanceStatistics +} + +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) invocationsDone() bool { + if len(mmGetBalanceStatistics.expectations) == 0 && mmGetBalanceStatistics.defaultExpectation == nil && mmGetBalanceStatistics.mock.funcGetBalanceStatistics == nil { + return true + } + + totalInvocations := mm_atomic.LoadUint64(&mmGetBalanceStatistics.mock.afterGetBalanceStatisticsCounter) + expectedInvocations := mm_atomic.LoadUint64(&mmGetBalanceStatistics.expectedInvocations) + + return totalInvocations > 0 && (expectedInvocations == 0 || expectedInvocations == totalInvocations) +} + +// GetBalanceStatistics implements mm_service.UserService +func (mmGetBalanceStatistics *UserServiceMock) GetBalanceStatistics(ctx context.Context, userID int) (bp1 *mm_service.BalanceStatistics, err error) { + mm_atomic.AddUint64(&mmGetBalanceStatistics.beforeGetBalanceStatisticsCounter, 1) + defer mm_atomic.AddUint64(&mmGetBalanceStatistics.afterGetBalanceStatisticsCounter, 1) + + mmGetBalanceStatistics.t.Helper() + + if mmGetBalanceStatistics.inspectFuncGetBalanceStatistics != nil { + mmGetBalanceStatistics.inspectFuncGetBalanceStatistics(ctx, userID) + } + + mm_params := UserServiceMockGetBalanceStatisticsParams{ctx, userID} + + // Record call args + mmGetBalanceStatistics.GetBalanceStatisticsMock.mutex.Lock() + mmGetBalanceStatistics.GetBalanceStatisticsMock.callArgs = append(mmGetBalanceStatistics.GetBalanceStatisticsMock.callArgs, &mm_params) + mmGetBalanceStatistics.GetBalanceStatisticsMock.mutex.Unlock() + + for _, e := range mmGetBalanceStatistics.GetBalanceStatisticsMock.expectations { + if minimock.Equal(*e.params, mm_params) { + mm_atomic.AddUint64(&e.Counter, 1) + return e.results.bp1, e.results.err + } + } + + if mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation != nil { + mm_atomic.AddUint64(&mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.Counter, 1) + mm_want := mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.params + mm_want_ptrs := mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.paramPtrs + + mm_got := UserServiceMockGetBalanceStatisticsParams{ctx, userID} + + if mm_want_ptrs != nil { + + if mm_want_ptrs.ctx != nil && !minimock.Equal(*mm_want_ptrs.ctx, mm_got.ctx) { + mmGetBalanceStatistics.t.Errorf("UserServiceMock.GetBalanceStatistics got unexpected parameter ctx, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.originCtx, *mm_want_ptrs.ctx, mm_got.ctx, minimock.Diff(*mm_want_ptrs.ctx, mm_got.ctx)) + } + + if mm_want_ptrs.userID != nil && !minimock.Equal(*mm_want_ptrs.userID, mm_got.userID) { + mmGetBalanceStatistics.t.Errorf("UserServiceMock.GetBalanceStatistics got unexpected parameter userID, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.originUserID, *mm_want_ptrs.userID, mm_got.userID, minimock.Diff(*mm_want_ptrs.userID, mm_got.userID)) + } + + } else if mm_want != nil && !minimock.Equal(*mm_want, mm_got) { + mmGetBalanceStatistics.t.Errorf("UserServiceMock.GetBalanceStatistics got unexpected parameters, expected at\n%s:\nwant: %#v\n got: %#v%s\n", + mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.origin, *mm_want, mm_got, minimock.Diff(*mm_want, mm_got)) + } + + mm_results := mmGetBalanceStatistics.GetBalanceStatisticsMock.defaultExpectation.results + if mm_results == nil { + mmGetBalanceStatistics.t.Fatal("No results are set for the UserServiceMock.GetBalanceStatistics") + } + return (*mm_results).bp1, (*mm_results).err + } + if mmGetBalanceStatistics.funcGetBalanceStatistics != nil { + return mmGetBalanceStatistics.funcGetBalanceStatistics(ctx, userID) + } + mmGetBalanceStatistics.t.Fatalf("Unexpected call to UserServiceMock.GetBalanceStatistics. %v %v", ctx, userID) + return +} + +// GetBalanceStatisticsAfterCounter returns a count of finished UserServiceMock.GetBalanceStatistics invocations +func (mmGetBalanceStatistics *UserServiceMock) GetBalanceStatisticsAfterCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetBalanceStatistics.afterGetBalanceStatisticsCounter) +} + +// GetBalanceStatisticsBeforeCounter returns a count of UserServiceMock.GetBalanceStatistics invocations +func (mmGetBalanceStatistics *UserServiceMock) GetBalanceStatisticsBeforeCounter() uint64 { + return mm_atomic.LoadUint64(&mmGetBalanceStatistics.beforeGetBalanceStatisticsCounter) +} + +// Calls returns a list of arguments used in each call to UserServiceMock.GetBalanceStatistics. +// The list is in the same order as the calls were made (i.e. recent calls have a higher index) +func (mmGetBalanceStatistics *mUserServiceMockGetBalanceStatistics) Calls() []*UserServiceMockGetBalanceStatisticsParams { + mmGetBalanceStatistics.mutex.RLock() + + argCopy := make([]*UserServiceMockGetBalanceStatisticsParams, len(mmGetBalanceStatistics.callArgs)) + copy(argCopy, mmGetBalanceStatistics.callArgs) + + mmGetBalanceStatistics.mutex.RUnlock() + + return argCopy +} + +// MinimockGetBalanceStatisticsDone returns true if the count of the GetBalanceStatistics invocations corresponds +// the number of defined expectations +func (m *UserServiceMock) MinimockGetBalanceStatisticsDone() bool { + if m.GetBalanceStatisticsMock.optional { + // Optional methods provide '0 or more' call count restriction. + return true + } + + for _, e := range m.GetBalanceStatisticsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + return false + } + } + + return m.GetBalanceStatisticsMock.invocationsDone() +} + +// MinimockGetBalanceStatisticsInspect logs each unmet expectation +func (m *UserServiceMock) MinimockGetBalanceStatisticsInspect() { + for _, e := range m.GetBalanceStatisticsMock.expectations { + if mm_atomic.LoadUint64(&e.Counter) < 1 { + m.t.Errorf("Expected call to UserServiceMock.GetBalanceStatistics at\n%s with params: %#v", e.expectationOrigins.origin, *e.params) + } + } + + afterGetBalanceStatisticsCounter := mm_atomic.LoadUint64(&m.afterGetBalanceStatisticsCounter) + // if default expectation was set then invocations count should be greater than zero + if m.GetBalanceStatisticsMock.defaultExpectation != nil && afterGetBalanceStatisticsCounter < 1 { + if m.GetBalanceStatisticsMock.defaultExpectation.params == nil { + m.t.Errorf("Expected call to UserServiceMock.GetBalanceStatistics at\n%s", m.GetBalanceStatisticsMock.defaultExpectation.returnOrigin) + } else { + m.t.Errorf("Expected call to UserServiceMock.GetBalanceStatistics at\n%s with params: %#v", m.GetBalanceStatisticsMock.defaultExpectation.expectationOrigins.origin, *m.GetBalanceStatisticsMock.defaultExpectation.params) + } + } + // if func was set then invocations count should be greater than zero + if m.funcGetBalanceStatistics != nil && afterGetBalanceStatisticsCounter < 1 { + m.t.Errorf("Expected call to UserServiceMock.GetBalanceStatistics at\n%s", m.funcGetBalanceStatisticsOrigin) + } + + if !m.GetBalanceStatisticsMock.invocationsDone() && afterGetBalanceStatisticsCounter > 0 { + m.t.Errorf("Expected %d calls to UserServiceMock.GetBalanceStatistics at\n%s but found %d calls", + mm_atomic.LoadUint64(&m.GetBalanceStatisticsMock.expectedInvocations), m.GetBalanceStatisticsMock.expectedInvocationsOrigin, afterGetBalanceStatisticsCounter) + } +} + type mUserServiceMockGetInfo struct { optional bool mock *UserServiceMock @@ -1098,6 +1451,8 @@ func (m *UserServiceMock) MinimockFinish() { if !m.minimockDone() { m.MinimockGetBalanceInspect() + m.MinimockGetBalanceStatisticsInspect() + m.MinimockGetInfoInspect() m.MinimockGetStatisticsInspect() @@ -1125,6 +1480,7 @@ func (m *UserServiceMock) minimockDone() bool { done := true return done && m.MinimockGetBalanceDone() && + m.MinimockGetBalanceStatisticsDone() && m.MinimockGetInfoDone() && m.MinimockGetStatisticsDone() } diff --git a/internal/model/invite.go b/internal/model/invite.go index a15e84c..2d81f28 100644 --- a/internal/model/invite.go +++ b/internal/model/invite.go @@ -7,7 +7,6 @@ type InviteCode struct { UserID int Code int64 CanBeUsedCount int - UsedCount int IsActive bool CreatedAt time.Time ExpiresAt time.Time diff --git a/internal/model/supplier.go b/internal/model/supplier.go index 285c6a1..95d9413 100644 --- a/internal/model/supplier.go +++ b/internal/model/supplier.go @@ -26,3 +26,9 @@ type TokenUsage struct { Type string CreatedAt time.Time } + +type WriteOffHistory struct { + OperationID string + Data string + Amount float64 +} diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index c4a98ad..02f1ee1 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -37,7 +37,7 @@ type InviteRepository interface { CreateTx(ctx context.Context, tx pgx.Tx, invite *model.InviteCode) error FindByCode(ctx context.Context, code int64) (*model.InviteCode, error) FindActiveByCode(ctx context.Context, code int64) (*model.InviteCode, error) - IncrementUsedCount(ctx context.Context, code int64) error + FindActiveByUserID(ctx context.Context, userID int) (*model.InviteCode, error) DecrementCanBeUsedCountTx(ctx context.Context, tx pgx.Tx, code int64) error DeactivateExpired(ctx context.Context) (int, error) GetUserInvites(ctx context.Context, userID int) ([]*model.InviteCode, error) @@ -52,6 +52,7 @@ type RequestRepository interface { GetByID(ctx context.Context, id uuid.UUID) (*model.Request, error) GetDetailByID(ctx context.Context, id uuid.UUID) (*model.RequestDetail, error) GetUserStatistics(ctx context.Context, userID int) (requestsCount, suppliersCount, createdTZ int, err error) + CheckOwnership(ctx context.Context, requestID uuid.UUID, userID int) (bool, error) } type SupplierRepository interface { @@ -64,4 +65,5 @@ type SupplierRepository interface { type TokenUsageRepository interface { Create(ctx context.Context, usage *model.TokenUsage) error CreateTx(ctx context.Context, tx pgx.Tx, usage *model.TokenUsage) error + GetBalanceStatistics(ctx context.Context, userID int) (averageCost float64, history []*model.WriteOffHistory, err error) } diff --git a/internal/repository/invite.go b/internal/repository/invite.go index f1453ab..850a7a7 100644 --- a/internal/repository/invite.go +++ b/internal/repository/invite.go @@ -53,7 +53,7 @@ func (r *inviteRepository) createWithExecutor(ctx context.Context, exec DBTX, in func (r *inviteRepository) FindByCode(ctx context.Context, code int64) (*model.InviteCode, error) { query := r.qb.Select( - "id", "user_id", "code", "can_be_used_count", "used_count", + "id", "user_id", "code", "can_be_used_count", "is_active", "created_at", "expires_at", ).From("invite_codes").Where(sq.Eq{"code": code}) @@ -65,7 +65,7 @@ func (r *inviteRepository) FindByCode(ctx context.Context, code int64) (*model.I invite := &model.InviteCode{} err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, - &invite.UsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, + &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, ) if errors.Is(err, pgx.ErrNoRows) { @@ -80,13 +80,13 @@ func (r *inviteRepository) FindByCode(ctx context.Context, code int64) (*model.I func (r *inviteRepository) FindActiveByCode(ctx context.Context, code int64) (*model.InviteCode, error) { query := r.qb.Select( - "id", "user_id", "code", "can_be_used_count", "used_count", + "id", "user_id", "code", "can_be_used_count", "is_active", "created_at", "expires_at", ).From("invite_codes").Where(sq.And{ sq.Eq{"code": code}, sq.Eq{"is_active": true}, sq.Expr("expires_at > now()"), - sq.Expr("can_be_used_count > used_count"), + sq.Expr("can_be_used_count > 0"), }) sqlQuery, args, err := query.ToSql() @@ -97,7 +97,7 @@ func (r *inviteRepository) FindActiveByCode(ctx context.Context, code int64) (*m invite := &model.InviteCode{} err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, - &invite.UsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, + &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, ) if errors.Is(err, pgx.ErrNoRows) { @@ -110,40 +110,63 @@ func (r *inviteRepository) FindActiveByCode(ctx context.Context, code int64) (*m return invite, nil } -func (r *inviteRepository) IncrementUsedCount(ctx context.Context, code int64) error { - query := r.qb.Update("invite_codes"). - Set("used_count", sq.Expr("used_count + 1")). - Where(sq.Eq{"code": code}) +func (r *inviteRepository) FindActiveByUserID(ctx context.Context, userID int) (*model.InviteCode, error) { + query := r.qb.Select( + "id", "user_id", "code", "can_be_used_count", + "is_active", "created_at", "expires_at", + ).From("invite_codes").Where(sq.And{ + sq.Eq{"user_id": userID}, + sq.Eq{"is_active": true}, + sq.Expr("expires_at > now()"), + sq.Expr("can_be_used_count > 0"), + }).OrderBy("created_at DESC").Limit(1) sqlQuery, args, err := query.ToSql() if err != nil { - return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) + return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } - _, err = r.pool.Exec(ctx, sqlQuery, args...) + invite := &model.InviteCode{} + err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( + &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, + &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, + ) + + if errors.Is(err, pgx.ErrNoRows) { + return nil, errs.NewBusinessError(errs.InviteInvalidOrExpired, "no active invite code found") + } if err != nil { - return errs.NewInternalError(errs.DatabaseError, "failed to increment used count", err) + return nil, errs.NewInternalError(errs.DatabaseError, "failed to find active invite code by user", err) } - return nil + return invite, nil } func (r *inviteRepository) DecrementCanBeUsedCountTx(ctx context.Context, tx pgx.Tx, code int64) error { query := r.qb.Update("invite_codes"). - Set("used_count", sq.Expr("used_count + 1")). - Set("is_active", sq.Expr("CASE WHEN used_count + 1 >= can_be_used_count THEN false ELSE is_active END")). - Where(sq.Eq{"code": code}) + Set("can_be_used_count", sq.Expr("can_be_used_count - 1")). + Set("is_active", sq.Expr("CASE WHEN can_be_used_count - 1 <= 0 THEN false ELSE is_active END")). + Where(sq.And{ + sq.Eq{"code": code}, + sq.Expr("can_be_used_count > 0"), + sq.Eq{"is_active": true}, + sq.Expr("expires_at > now()"), + }) sqlQuery, args, err := query.ToSql() if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } - _, err = tx.Exec(ctx, sqlQuery, args...) + result, err := tx.Exec(ctx, sqlQuery, args...) if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to decrement can_be_used_count", err) } + if result.RowsAffected() == 0 { + return errs.NewBusinessError(errs.InviteInvalidOrExpired, "invite code is invalid, expired, or exhausted") + } + return nil } @@ -170,7 +193,7 @@ func (r *inviteRepository) DeactivateExpired(ctx context.Context) (int, error) { func (r *inviteRepository) GetUserInvites(ctx context.Context, userID int) ([]*model.InviteCode, error) { query := r.qb.Select( - "id", "user_id", "code", "can_be_used_count", "used_count", + "id", "user_id", "code", "can_be_used_count", "is_active", "created_at", "expires_at", ).From("invite_codes"). Where(sq.Eq{"user_id": userID}). @@ -192,7 +215,7 @@ func (r *inviteRepository) GetUserInvites(ctx context.Context, userID int) ([]*m invite := &model.InviteCode{} err := rows.Scan( &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, - &invite.UsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, + &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, ) if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to scan invite", err) diff --git a/internal/repository/request.go b/internal/repository/request.go index d31215e..cd22d00 100644 --- a/internal/repository/request.go +++ b/internal/repository/request.go @@ -153,30 +153,33 @@ func (r *requestRepository) GetByID(ctx context.Context, id uuid.UUID) (*model.R } func (r *requestRepository) GetDetailByID(ctx context.Context, id uuid.UUID) (*model.RequestDetail, error) { - sqlQuery := ` - SELECT - r.id AS request_id, - r.request_txt AS title, - r.final_update_tz AS mail_text, - COALESCE(json_agg( - json_build_object( - 'email', COALESCE(s.email, ''), - 'phone', COALESCE(s.phone, ''), - 'company_name', COALESCE(s.name, ''), - 'company_id', s.id, - 'url', COALESCE(s.url, '') - ) - ) FILTER (WHERE s.id IS NOT NULL), '[]') AS suppliers - FROM requests_for_suppliers r - LEFT JOIN suppliers s ON s.request_id = r.id - WHERE r.id = $1 - GROUP BY r.id, r.request_txt, r.final_update_tz - ` + query := r.qb.Select( + "r.id AS request_id", + "r.request_txt AS title", + "r.final_update_tz AS mail_text", + `COALESCE(json_agg( + json_build_object( + 'email', COALESCE(s.email, ''), + 'phone', COALESCE(s.phone, ''), + 'company_name', COALESCE(s.name, ''), + 'company_id', s.id, + 'url', COALESCE(s.url, '') + ) + ) FILTER (WHERE s.id IS NOT NULL), '[]') AS suppliers`, + ).From("requests_for_suppliers r"). + LeftJoin("suppliers s ON s.request_id = r.id"). + Where(sq.Eq{"r.id": id}). + GroupBy("r.id", "r.request_txt", "r.final_update_tz") + + sqlQuery, args, err := query.ToSql() + if err != nil { + return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) + } detail := &model.RequestDetail{} var suppliersJSON []byte - err := r.pool.QueryRow(ctx, sqlQuery, id).Scan( + err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( &detail.RequestID, &detail.Title, &detail.MailText, &suppliersJSON, ) @@ -197,20 +200,44 @@ func (r *requestRepository) GetDetailByID(ctx context.Context, id uuid.UUID) (*m } func (r *requestRepository) GetUserStatistics(ctx context.Context, userID int) (requestsCount, suppliersCount, createdTZ int, err error) { - sqlQuery := ` - SELECT - COUNT(DISTINCT r.id) AS requests_count, - COUNT(s.id) AS suppliers_count, - COUNT(r.request_txt) AS created_tz - FROM requests_for_suppliers r - LEFT JOIN suppliers s ON s.request_id = r.id - WHERE r.user_id = $1 - ` + query := r.qb.Select( + "COUNT(DISTINCT r.id) AS requests_count", + "COUNT(s.id) AS suppliers_count", + "COUNT(DISTINCT CASE WHEN r.request_txt IS NOT NULL THEN r.id END) AS created_tz", + ).From("requests_for_suppliers r"). + LeftJoin("suppliers s ON s.request_id = r.id"). + Where(sq.Eq{"r.user_id": userID}) - err = r.pool.QueryRow(ctx, sqlQuery, userID).Scan(&requestsCount, &suppliersCount, &createdTZ) + sqlQuery, args, err := query.ToSql() + if err != nil { + return 0, 0, 0, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) + } + + err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&requestsCount, &suppliersCount, &createdTZ) if err != nil { return 0, 0, 0, errs.NewInternalError(errs.DatabaseError, "failed to get statistics", err) } return requestsCount, suppliersCount, createdTZ, nil } + +func (r *requestRepository) CheckOwnership(ctx context.Context, requestID uuid.UUID, userID int) (bool, error) { + query := r.qb.Select("1").From("requests_for_suppliers"). + Where(sq.Eq{"id": requestID, "user_id": userID}) + + sqlQuery, args, err := query.ToSql() + if err != nil { + return false, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) + } + + var exists int + err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&exists) + if errors.Is(err, pgx.ErrNoRows) { + return false, nil + } + if err != nil { + return false, errs.NewInternalError(errs.DatabaseError, "failed to check ownership", err) + } + + return true, nil +} diff --git a/internal/repository/supplier.go b/internal/repository/supplier.go index 18b693c..1767387 100644 --- a/internal/repository/supplier.go +++ b/internal/repository/supplier.go @@ -44,15 +44,30 @@ func (r *supplierRepository) bulkInsertWithExecutor(ctx context.Context, exec DB query = query.Values(requestID, s.Name, s.Email, s.Phone, s.Address, s.URL) } + query = query.Suffix("RETURNING id") + sqlQuery, args, err := query.ToSql() if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } - _, err = exec.Exec(ctx, sqlQuery, args...) + rows, err := exec.Query(ctx, sqlQuery, args...) if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to bulk insert suppliers", err) } + defer rows.Close() + + i := 0 + for rows.Next() { + if i >= len(suppliers) { + break + } + if err := rows.Scan(&suppliers[i].ID); err != nil { + return errs.NewInternalError(errs.DatabaseError, "failed to scan supplier id", err) + } + suppliers[i].RequestID = requestID + i++ + } return nil } diff --git a/internal/repository/token_usage.go b/internal/repository/token_usage.go index 30ec8e5..1c09db8 100644 --- a/internal/repository/token_usage.go +++ b/internal/repository/token_usage.go @@ -2,6 +2,7 @@ package repository import ( "context" + "strconv" "git.techease.ru/Smart-search/smart-search-back/internal/model" errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors" @@ -49,3 +50,59 @@ func (r *tokenUsageRepository) createWithExecutor(ctx context.Context, exec DBTX return nil } + +func (r *tokenUsageRepository) GetBalanceStatistics(ctx context.Context, userID int) (float64, []*model.WriteOffHistory, error) { + avgQuery := r.qb.Select("ROUND(COALESCE(AVG(COALESCE(rtu.token_cost, 0)), 0)::numeric, 2)"). + From("request_token_usage rtu"). + Join("requests_for_suppliers rfs ON rtu.request_id = rfs.id"). + Where(sq.Eq{"rfs.user_id": userID}) + + avgSQL, avgArgs, err := avgQuery.ToSql() + if err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to build average query", err) + } + + var averageCost float64 + if err := r.pool.QueryRow(ctx, avgSQL, avgArgs...).Scan(&averageCost); err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to get average cost", err) + } + + historyQuery := r.qb.Select( + "rtu.id", + "TO_CHAR(rtu.created_at, 'DD-MM-YYYY')", + "ROUND(COALESCE(rtu.token_cost, 0)::numeric, 2)", + ). + From("request_token_usage rtu"). + Join("requests_for_suppliers rfs ON rtu.request_id = rfs.id"). + Where(sq.Eq{"rfs.user_id": userID}). + OrderBy("rtu.created_at DESC"). + Limit(8) + + historySQL, historyArgs, err := historyQuery.ToSql() + if err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to build history query", err) + } + + rows, err := r.pool.Query(ctx, historySQL, historyArgs...) + if err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to get write-off history", err) + } + defer rows.Close() + + var history []*model.WriteOffHistory + for rows.Next() { + var operationID int + var item model.WriteOffHistory + if err := rows.Scan(&operationID, &item.Data, &item.Amount); err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to scan write-off history", err) + } + item.OperationID = strconv.Itoa(operationID) + history = append(history, &item) + } + + if err := rows.Err(); err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to iterate write-off history", err) + } + + return averageCost, history, nil +} diff --git a/internal/service/auth.go b/internal/service/auth.go index a14d576..9474854 100644 --- a/internal/service/auth.go +++ b/internal/service/auth.go @@ -61,7 +61,7 @@ func (s *authService) Login(ctx context.Context, email, password, ip, userAgent RefreshToken: refreshToken, IP: ip, UserAgent: userAgent, - ExpiresAt: time.Now().Add(30 * 24 * time.Hour), + ExpiresAt: time.Now().Add(24 * time.Hour), } if err := s.sessionRepo.Create(ctx, session); err != nil { @@ -175,7 +175,7 @@ func (s *authService) Register(ctx context.Context, email, password, name, phone RefreshToken: refreshToken, IP: ip, UserAgent: userAgent, - ExpiresAt: time.Now().Add(30 * 24 * time.Hour), + ExpiresAt: time.Now().Add(24 * time.Hour), } if err := s.sessionRepo.Create(ctx, session); err != nil { diff --git a/internal/service/interfaces.go b/internal/service/interfaces.go index ffa09fa..1634290 100644 --- a/internal/service/interfaces.go +++ b/internal/service/interfaces.go @@ -20,20 +20,21 @@ type UserService interface { GetInfo(ctx context.Context, userID int) (*UserInfo, error) GetBalance(ctx context.Context, userID int) (float64, error) GetStatistics(ctx context.Context, userID int) (*Statistics, error) + GetBalanceStatistics(ctx context.Context, userID int) (*BalanceStatistics, error) } type InviteService interface { Generate(ctx context.Context, userID, maxUses, ttlDays int) (*model.InviteCode, error) - GetInfo(ctx context.Context, code int64) (*model.InviteCode, error) + GetInfo(ctx context.Context, userID int) (*model.InviteCode, error) } type RequestService interface { - CreateTZ(ctx context.Context, userID int, requestTxt string) (uuid.UUID, string, error) + CreateTZ(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (uuid.UUID, string, error) ApproveTZ(ctx context.Context, requestID uuid.UUID, tzText string, userID int) ([]*model.Supplier, error) GetMailingList(ctx context.Context, userID int) ([]*model.Request, error) - GetMailingListByID(ctx context.Context, requestID uuid.UUID) (*model.RequestDetail, error) + GetMailingListByID(ctx context.Context, requestID uuid.UUID, userID int) (*model.RequestDetail, error) } type SupplierService interface { - ExportExcel(ctx context.Context, requestID uuid.UUID) ([]byte, error) + ExportExcel(ctx context.Context, requestID uuid.UUID, userID int) ([]byte, error) } diff --git a/internal/service/invite.go b/internal/service/invite.go index e2b0f5b..2b92ca9 100644 --- a/internal/service/invite.go +++ b/internal/service/invite.go @@ -63,6 +63,6 @@ func (s *inviteService) Generate(ctx context.Context, userID, maxUses, ttlDays i return invite, nil } -func (s *inviteService) GetInfo(ctx context.Context, code int64) (*model.InviteCode, error) { - return s.inviteRepo.FindByCode(ctx, code) +func (s *inviteService) GetInfo(ctx context.Context, userID int) (*model.InviteCode, error) { + return s.inviteRepo.FindActiveByUserID(ctx, userID) } diff --git a/internal/service/request.go b/internal/service/request.go index 41483e8..769f19d 100644 --- a/internal/service/request.go +++ b/internal/service/request.go @@ -2,12 +2,14 @@ package service import ( "context" + "fmt" "math" "git.techease.ru/Smart-search/smart-search-back/internal/ai" "git.techease.ru/Smart-search/smart-search-back/internal/model" "git.techease.ru/Smart-search/smart-search-back/internal/repository" "git.techease.ru/Smart-search/smart-search-back/pkg/errors" + "git.techease.ru/Smart-search/smart-search-back/pkg/fileparser" "github.com/google/uuid" "github.com/jackc/pgx/v5" ) @@ -42,21 +44,37 @@ func NewRequestService( } } -func (s *requestService) CreateTZ(ctx context.Context, userID int, requestTxt string) (uuid.UUID, string, error) { +func (s *requestService) CreateTZ(ctx context.Context, userID int, requestTxt string, fileData []byte, fileName string) (uuid.UUID, string, error) { + combinedText := requestTxt + + if len(fileData) > 0 && fileName != "" { + fileContent, err := fileparser.ExtractText(fileData, fileName) + if err != nil { + return uuid.Nil, "", err + } + if fileContent != "" { + if combinedText != "" { + combinedText = fmt.Sprintf("%s\n\nСодержимое файла (%s):\n%s", combinedText, fileName, fileContent) + } else { + combinedText = fmt.Sprintf("Содержимое файла (%s):\n%s", fileName, fileContent) + } + } + } + req := &model.Request{ UserID: userID, - RequestTxt: requestTxt, + RequestTxt: combinedText, } if err := s.requestRepo.Create(ctx, req); err != nil { return uuid.Nil, "", err } - if requestTxt == "" { + if combinedText == "" { return req.ID, "", nil } - tzText, err := s.openAI.GenerateTZ(requestTxt) + tzText, err := s.openAI.GenerateTZ(combinedText) if err != nil { if err := s.requestRepo.UpdateWithTZ(ctx, req.ID, "", false); err != nil { return req.ID, "", err @@ -107,13 +125,20 @@ func (s *requestService) CreateTZ(ctx context.Context, userID int, requestTxt st } func (s *requestService) ApproveTZ(ctx context.Context, requestID uuid.UUID, tzText string, userID int) ([]*model.Supplier, error) { - if err := s.requestRepo.UpdateFinalTZ(ctx, requestID, tzText); err != nil { + isOwner, err := s.requestRepo.CheckOwnership(ctx, requestID, userID) + if err != nil { + return nil, err + } + if !isOwner { + return nil, errors.NewBusinessError(errors.PermissionDenied, "access denied to this request") + } + + if err = s.requestRepo.UpdateFinalTZ(ctx, requestID, tzText); err != nil { return nil, err } var suppliers []*model.Supplier var promptTokens, responseTokens int - var err error for attempt := 0; attempt < 3; attempt++ { suppliers, promptTokens, responseTokens, err = s.perplexity.FindSuppliers(tzText) @@ -169,6 +194,14 @@ func (s *requestService) GetMailingList(ctx context.Context, userID int) ([]*mod return s.requestRepo.GetByUserID(ctx, userID) } -func (s *requestService) GetMailingListByID(ctx context.Context, requestID uuid.UUID) (*model.RequestDetail, error) { +func (s *requestService) GetMailingListByID(ctx context.Context, requestID uuid.UUID, userID int) (*model.RequestDetail, error) { + isOwner, err := s.requestRepo.CheckOwnership(ctx, requestID, userID) + if err != nil { + return nil, err + } + if !isOwner { + return nil, errors.NewBusinessError(errors.PermissionDenied, "access denied to this request") + } + return s.requestRepo.GetDetailByID(ctx, requestID) } diff --git a/internal/service/supplier.go b/internal/service/supplier.go index dbdf441..8fc506e 100644 --- a/internal/service/supplier.go +++ b/internal/service/supplier.go @@ -5,21 +5,32 @@ import ( "fmt" "git.techease.ru/Smart-search/smart-search-back/internal/repository" + "git.techease.ru/Smart-search/smart-search-back/pkg/errors" "github.com/google/uuid" "github.com/xuri/excelize/v2" ) type supplierService struct { supplierRepo repository.SupplierRepository + requestRepo repository.RequestRepository } -func NewSupplierService(supplierRepo repository.SupplierRepository) SupplierService { +func NewSupplierService(supplierRepo repository.SupplierRepository, requestRepo repository.RequestRepository) SupplierService { return &supplierService{ supplierRepo: supplierRepo, + requestRepo: requestRepo, } } -func (s *supplierService) ExportExcel(ctx context.Context, requestID uuid.UUID) ([]byte, error) { +func (s *supplierService) ExportExcel(ctx context.Context, requestID uuid.UUID, userID int) ([]byte, error) { + isOwner, err := s.requestRepo.CheckOwnership(ctx, requestID, userID) + if err != nil { + return nil, err + } + if !isOwner { + return nil, errors.NewBusinessError(errors.PermissionDenied, "access denied to this request") + } + suppliers, err := s.supplierRepo.GetByRequestID(ctx, requestID) if err != nil { return nil, err diff --git a/internal/service/user.go b/internal/service/user.go index fcb19f0..9f2ccda 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -3,14 +3,16 @@ package service import ( "context" + "git.techease.ru/Smart-search/smart-search-back/internal/model" "git.techease.ru/Smart-search/smart-search-back/internal/repository" "git.techease.ru/Smart-search/smart-search-back/pkg/crypto" ) type userService struct { - userRepo repository.UserRepository - requestRepo repository.RequestRepository - cryptoHelper *crypto.Crypto + userRepo repository.UserRepository + requestRepo repository.RequestRepository + tokenUsageRepo repository.TokenUsageRepository + cryptoHelper *crypto.Crypto } type UserInfo struct { @@ -27,11 +29,17 @@ type Statistics struct { CreatedTZ int } -func NewUserService(userRepo repository.UserRepository, requestRepo repository.RequestRepository, cryptoSecret string) UserService { +type BalanceStatistics struct { + AverageCost float64 + WriteOffHistory []*model.WriteOffHistory +} + +func NewUserService(userRepo repository.UserRepository, requestRepo repository.RequestRepository, tokenUsageRepo repository.TokenUsageRepository, cryptoSecret string) UserService { return &userService{ - userRepo: userRepo, - requestRepo: requestRepo, - cryptoHelper: crypto.NewCrypto(cryptoSecret), + userRepo: userRepo, + requestRepo: requestRepo, + tokenUsageRepo: tokenUsageRepo, + cryptoHelper: crypto.NewCrypto(cryptoSecret), } } @@ -81,3 +89,15 @@ func (s *userService) GetStatistics(ctx context.Context, userID int) (*Statistic CreatedTZ: createdTZ, }, nil } + +func (s *userService) GetBalanceStatistics(ctx context.Context, userID int) (*BalanceStatistics, error) { + averageCost, history, err := s.tokenUsageRepo.GetBalanceStatistics(ctx, userID) + if err != nil { + return nil, err + } + + return &BalanceStatistics{ + AverageCost: averageCost, + WriteOffHistory: history, + }, nil +} diff --git a/internal/worker/invite_cleaner.go b/internal/worker/invite_cleaner.go index 7cb3680..f4aa9d1 100644 --- a/internal/worker/invite_cleaner.go +++ b/internal/worker/invite_cleaner.go @@ -2,9 +2,10 @@ package worker import ( "context" - "log" "time" + "go.uber.org/zap" + "git.techease.ru/Smart-search/smart-search-back/internal/repository" ) @@ -13,13 +14,15 @@ type InviteCleaner struct { ctx context.Context ticker *time.Ticker done chan bool + logger *zap.Logger } -func NewInviteCleaner(ctx context.Context, inviteRepo repository.InviteRepository) *InviteCleaner { +func NewInviteCleaner(ctx context.Context, inviteRepo repository.InviteRepository, logger *zap.Logger) *InviteCleaner { return &InviteCleaner{ inviteRepo: inviteRepo, ctx: ctx, done: make(chan bool), + logger: logger, } } @@ -36,13 +39,13 @@ func (w *InviteCleaner) Start() { case <-w.done: return case <-w.ctx.Done(): - log.Println("Invite cleaner context cancelled, stopping worker") + w.logger.Info("Invite cleaner context cancelled, stopping worker") return } } }() - log.Println("Invite cleaner worker started (runs every 6 hours)") + w.logger.Info("Invite cleaner worker started (runs every 6 hours)") } func (w *InviteCleaner) Stop() { @@ -53,17 +56,17 @@ func (w *InviteCleaner) Stop() { case w.done <- true: default: } - log.Println("Invite cleaner worker stopped") + w.logger.Info("Invite cleaner worker stopped") } func (w *InviteCleaner) deactivateExpiredInvites() { count, err := w.inviteRepo.DeactivateExpired(w.ctx) if err != nil { - log.Printf("Error deactivating expired invites: %v", err) + w.logger.Error("Error deactivating expired invites", zap.Error(err)) return } if count > 0 { - log.Printf("Deactivated %d expired invite codes", count) + w.logger.Info("Deactivated expired invite codes", zap.Int("count", count)) } } diff --git a/internal/worker/session_cleaner.go b/internal/worker/session_cleaner.go index 97b6ddf..ed542c2 100644 --- a/internal/worker/session_cleaner.go +++ b/internal/worker/session_cleaner.go @@ -2,9 +2,10 @@ package worker import ( "context" - "log" "time" + "go.uber.org/zap" + "git.techease.ru/Smart-search/smart-search-back/internal/repository" ) @@ -13,13 +14,15 @@ type SessionCleaner struct { ctx context.Context ticker *time.Ticker done chan bool + logger *zap.Logger } -func NewSessionCleaner(ctx context.Context, sessionRepo repository.SessionRepository) *SessionCleaner { +func NewSessionCleaner(ctx context.Context, sessionRepo repository.SessionRepository, logger *zap.Logger) *SessionCleaner { return &SessionCleaner{ sessionRepo: sessionRepo, ctx: ctx, done: make(chan bool), + logger: logger, } } @@ -36,13 +39,13 @@ func (w *SessionCleaner) Start() { case <-w.done: return case <-w.ctx.Done(): - log.Println("Session cleaner context cancelled, stopping worker") + w.logger.Info("Session cleaner context cancelled, stopping worker") return } } }() - log.Println("Session cleaner worker started (runs every hour)") + w.logger.Info("Session cleaner worker started (runs every hour)") } func (w *SessionCleaner) Stop() { @@ -53,17 +56,17 @@ func (w *SessionCleaner) Stop() { case w.done <- true: default: } - log.Println("Session cleaner worker stopped") + w.logger.Info("Session cleaner worker stopped") } func (w *SessionCleaner) cleanExpiredSessions() { count, err := w.sessionRepo.DeleteExpired(w.ctx) if err != nil { - log.Printf("Error cleaning expired sessions: %v", err) + w.logger.Error("Error cleaning expired sessions", zap.Error(err)) return } if count > 0 { - log.Printf("Cleaned %d expired sessions", count) + w.logger.Info("Cleaned expired sessions", zap.Int("count", count)) } } diff --git a/internal/worker/worker_test.go b/internal/worker/worker_test.go index c6b864e..4c4bf2e 100644 --- a/internal/worker/worker_test.go +++ b/internal/worker/worker_test.go @@ -8,6 +8,7 @@ import ( "github.com/gojuno/minimock/v3" "github.com/stretchr/testify/suite" + "go.uber.org/zap" "git.techease.ru/Smart-search/smart-search-back/internal/mocks" ) @@ -17,6 +18,7 @@ type WorkerSuite struct { ctx context.Context cancel context.CancelFunc ctrl *minimock.Controller + logger *zap.Logger } func TestWorkerSuite(t *testing.T) { @@ -26,6 +28,7 @@ func TestWorkerSuite(t *testing.T) { func (s *WorkerSuite) SetupTest() { s.ctx, s.cancel = context.WithCancel(context.Background()) s.ctrl = minimock.NewController(s.T()) + s.logger = zap.NewNop() } func (s *WorkerSuite) TearDownTest() { @@ -42,7 +45,7 @@ func (s *WorkerSuite) TestSessionCleaner_StartStop() { return 5, nil }) - cleaner := NewSessionCleaner(s.ctx, sessionRepo) + cleaner := NewSessionCleaner(s.ctx, sessionRepo, s.logger) cleaner.Start() @@ -62,7 +65,7 @@ func (s *WorkerSuite) TestSessionCleaner_ContextCancellation() { return 0, nil }) - cleaner := NewSessionCleaner(s.ctx, sessionRepo) + cleaner := NewSessionCleaner(s.ctx, sessionRepo, s.logger) cleaner.Start() @@ -84,7 +87,7 @@ func (s *WorkerSuite) TestInviteCleaner_StartStop() { return 3, nil }) - cleaner := NewInviteCleaner(s.ctx, inviteRepo) + cleaner := NewInviteCleaner(s.ctx, inviteRepo, s.logger) cleaner.Start() @@ -104,7 +107,7 @@ func (s *WorkerSuite) TestInviteCleaner_ContextCancellation() { return 0, nil }) - cleaner := NewInviteCleaner(s.ctx, inviteRepo) + cleaner := NewInviteCleaner(s.ctx, inviteRepo, s.logger) cleaner.Start() @@ -124,7 +127,7 @@ func (s *WorkerSuite) TestSessionCleaner_ConcurrentStops() { return 0, nil }) - cleaner := NewSessionCleaner(s.ctx, sessionRepo) + cleaner := NewSessionCleaner(s.ctx, sessionRepo, s.logger) cleaner.Start() @@ -151,7 +154,7 @@ func (s *WorkerSuite) TestInviteCleaner_ConcurrentStops() { return 0, nil }) - cleaner := NewInviteCleaner(s.ctx, inviteRepo) + cleaner := NewInviteCleaner(s.ctx, inviteRepo, s.logger) cleaner.Start() @@ -180,7 +183,7 @@ func (s *WorkerSuite) TestSessionCleaner_MultipleStartStop() { return 2, nil }) - cleaner := NewSessionCleaner(s.ctx, sessionRepo) + cleaner := NewSessionCleaner(s.ctx, sessionRepo, s.logger) for i := 0; i < 3; i++ { cleaner.Start() @@ -200,7 +203,7 @@ func (s *WorkerSuite) TestInviteCleaner_MultipleStartStop() { return 1, nil }) - cleaner := NewInviteCleaner(s.ctx, inviteRepo) + cleaner := NewInviteCleaner(s.ctx, inviteRepo, s.logger) for i := 0; i < 3; i++ { cleaner.Start() diff --git a/migrations/00009_drop_used_count_column.sql b/migrations/00009_drop_used_count_column.sql new file mode 100644 index 0000000..3737653 --- /dev/null +++ b/migrations/00009_drop_used_count_column.sql @@ -0,0 +1,5 @@ +-- +goose Up +ALTER TABLE invite_codes DROP COLUMN IF EXISTS used_count; + +-- +goose Down +ALTER TABLE invite_codes ADD COLUMN used_count INT DEFAULT 0; diff --git a/pkg/errors/codes.go b/pkg/errors/codes.go index 916bdbd..e6017e3 100644 --- a/pkg/errors/codes.go +++ b/pkg/errors/codes.go @@ -11,6 +11,9 @@ const ( InsufficientBalance = "INSUFFICIENT_BALANCE" UserNotFound = "USER_NOT_FOUND" RequestNotFound = "REQUEST_NOT_FOUND" + PermissionDenied = "PERMISSION_DENIED" + UnsupportedFileFormat = "UNSUPPORTED_FILE_FORMAT" + FileProcessingError = "FILE_PROCESSING_ERROR" DatabaseError = "DATABASE_ERROR" EncryptionError = "ENCRYPTION_ERROR" diff --git a/pkg/errors/errors.go b/pkg/errors/errors.go index b4f6ff3..463e200 100644 --- a/pkg/errors/errors.go +++ b/pkg/errors/errors.go @@ -5,6 +5,7 @@ import ( "fmt" "go.uber.org/zap" + "google.golang.org/genproto/googleapis/rpc/errdetails" "google.golang.org/grpc/codes" "google.golang.org/grpc/status" ) @@ -88,18 +89,31 @@ func ToGRPCError(err error, zapLogger *zap.Logger, method string) error { return status.Error(codes.Internal, "internal server error") } + var grpcCode codes.Code switch appErr.Code { case AuthInvalidCredentials, AuthMissing, AuthInvalidToken, RefreshInvalid: - return status.Error(codes.Unauthenticated, appErr.Message) + grpcCode = codes.Unauthenticated + case PermissionDenied: + grpcCode = codes.PermissionDenied case InviteLimitReached: - return status.Error(codes.ResourceExhausted, appErr.Message) - case InsufficientBalance, InviteInvalidOrExpired: - return status.Error(codes.FailedPrecondition, appErr.Message) + grpcCode = codes.ResourceExhausted + case InsufficientBalance: + grpcCode = codes.FailedPrecondition + case InviteInvalidOrExpired: + grpcCode = codes.NotFound case EmailAlreadyExists: - return status.Error(codes.AlreadyExists, appErr.Message) + grpcCode = codes.AlreadyExists case UserNotFound, RequestNotFound: - return status.Error(codes.NotFound, appErr.Message) + grpcCode = codes.NotFound default: - return status.Error(codes.Unknown, appErr.Message) + grpcCode = codes.Unknown } + + st, err := status.New(grpcCode, appErr.Message).WithDetails(&errdetails.ErrorInfo{ + Reason: appErr.Code, + }) + if err != nil { + return status.Error(grpcCode, appErr.Message) + } + return st.Err() } diff --git a/pkg/fileparser/parser.go b/pkg/fileparser/parser.go new file mode 100644 index 0000000..3d11386 --- /dev/null +++ b/pkg/fileparser/parser.go @@ -0,0 +1,113 @@ +package fileparser + +import ( + "archive/zip" + "bytes" + "encoding/xml" + "io" + "net/http" + "strings" + + "git.techease.ru/Smart-search/smart-search-back/pkg/errors" +) + +func ExtractText(data []byte, _ string) (string, error) { + if len(data) == 0 { + return "", nil + } + + mimeType := http.DetectContentType(data) + + switch { + case strings.HasPrefix(mimeType, "text/"): + return string(data), nil + case mimeType == "application/zip" || mimeType == "application/octet-stream": + if isDocx(data) { + return extractDocx(data) + } + return "", errors.NewBusinessError(errors.UnsupportedFileFormat, "поддерживаются только текстовые файлы (.txt) и документы Word (.docx)") + default: + return "", errors.NewBusinessError(errors.UnsupportedFileFormat, "неподдерживаемый формат файла: "+mimeType+", поддерживаются .txt и .docx") + } +} + +func isDocx(data []byte) bool { + reader, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return false + } + + for _, file := range reader.File { + if file.Name == "word/document.xml" { + return true + } + } + return false +} + +func extractDocx(data []byte) (string, error) { + reader, err := zip.NewReader(bytes.NewReader(data), int64(len(data))) + if err != nil { + return "", errors.NewInternalError(errors.FileProcessingError, "не удалось прочитать docx файл", err) + } + + var content string + for _, file := range reader.File { + if file.Name == "word/document.xml" { + rc, err := file.Open() + if err != nil { + return "", errors.NewInternalError(errors.FileProcessingError, "не удалось открыть содержимое документа", err) + } + defer func() { _ = rc.Close() }() + + xmlData, err := io.ReadAll(rc) + if err != nil { + return "", errors.NewInternalError(errors.FileProcessingError, "не удалось прочитать содержимое документа", err) + } + + content = extractTextFromXML(xmlData) + break + } + } + + return content, nil +} + +type docxDocument struct { + XMLName xml.Name `xml:"document"` + Body docxBody `xml:"body"` +} + +type docxBody struct { + Paragraphs []docxParagraph `xml:"p"` +} + +type docxParagraph struct { + Runs []docxRun `xml:"r"` +} + +type docxRun struct { + Text string `xml:"t"` +} + +func extractTextFromXML(data []byte) string { + var doc docxDocument + if err := xml.Unmarshal(data, &doc); err != nil { + return "" + } + + var result []string + for _, p := range doc.Body.Paragraphs { + var line []string + for _, r := range p.Runs { + if r.Text != "" { + line = append(line, r.Text) + } + } + if len(line) > 0 { + result = append(result, strings.Join(line, "")) + } + } + + return strings.Join(result, "\n") +} diff --git a/pkg/fileparser/parser_test.go b/pkg/fileparser/parser_test.go new file mode 100644 index 0000000..07650ca --- /dev/null +++ b/pkg/fileparser/parser_test.go @@ -0,0 +1,110 @@ +package fileparser + +import ( + "os" + "path/filepath" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestExtractText_EmptyData(t *testing.T) { + result, err := ExtractText(nil, "test.txt") + assert.NoError(t, err) + assert.Empty(t, result) + + result, err = ExtractText([]byte{}, "test.txt") + assert.NoError(t, err) + assert.Empty(t, result) +} + +func TestExtractText_PlainText(t *testing.T) { + content := "Тестовый текст для проверки" + result, err := ExtractText([]byte(content), "document.txt") + + assert.NoError(t, err) + assert.Equal(t, content, result) +} + +func TestExtractText_PlainTextWithNewlines(t *testing.T) { + content := "Первая строка\nВторая строка\nТретья строка" + result, err := ExtractText([]byte(content), "document.txt") + + assert.NoError(t, err) + assert.Equal(t, content, result) +} + +func TestExtractText_RealDocxFile(t *testing.T) { + testdataPath := filepath.Join("testdata", "test_document.docx") + data, err := os.ReadFile(testdataPath) + require.NoError(t, err, "не удалось прочитать тестовый файл") + + result, err := ExtractText(data, "тестовый.docx") + + assert.NoError(t, err) + assert.NotEmpty(t, result, "текст из docx не должен быть пустым") + t.Logf("Извлеченный текст из docx:\n%s", result) +} + +func TestExtractText_DocxWithAnyFilename(t *testing.T) { + testdataPath := filepath.Join("testdata", "test_document.docx") + data, err := os.ReadFile(testdataPath) + require.NoError(t, err) + + result1, err := ExtractText(data, "random_name_without_extension") + assert.NoError(t, err) + assert.NotEmpty(t, result1) + + result2, err := ExtractText(data, "document.pdf") + assert.NoError(t, err) + assert.NotEmpty(t, result2) + + assert.Equal(t, result1, result2, "результат должен быть одинаковым независимо от имени файла") +} + +func TestExtractText_UnsupportedFormat_PDF(t *testing.T) { + pdfHeader := []byte("%PDF-1.4\n") + result, err := ExtractText(pdfHeader, "document.pdf") + + assert.Error(t, err) + assert.Empty(t, result) + assert.Contains(t, err.Error(), "UNSUPPORTED_FILE_FORMAT") +} + +func TestExtractText_UnsupportedFormat_Image(t *testing.T) { + pngHeader := []byte{0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A} + result, err := ExtractText(pngHeader, "image.png") + + assert.Error(t, err) + assert.Empty(t, result) + assert.Contains(t, err.Error(), "UNSUPPORTED_FILE_FORMAT") +} + +func TestExtractText_InvalidDocx(t *testing.T) { + zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} + fakeZip := append(zipHeader, []byte("not a valid zip content")...) + + result, err := ExtractText(fakeZip, "fake.docx") + + assert.Error(t, err) + assert.Empty(t, result) +} + +func TestIsDocx_ValidDocx(t *testing.T) { + testdataPath := filepath.Join("testdata", "test_document.docx") + data, err := os.ReadFile(testdataPath) + require.NoError(t, err) + + assert.True(t, isDocx(data)) +} + +func TestIsDocx_RegularZip(t *testing.T) { + zipHeader := []byte{0x50, 0x4B, 0x03, 0x04} + assert.False(t, isDocx(zipHeader)) +} + +func TestIsDocx_NotZip(t *testing.T) { + textData := []byte("plain text content") + assert.False(t, isDocx(textData)) +} diff --git a/pkg/fileparser/testdata/test_document.docx b/pkg/fileparser/testdata/test_document.docx new file mode 100644 index 0000000..86bc44b Binary files /dev/null and b/pkg/fileparser/testdata/test_document.docx differ diff --git a/pkg/jwt/jwt.go b/pkg/jwt/jwt.go index d786523..7b90c91 100644 --- a/pkg/jwt/jwt.go +++ b/pkg/jwt/jwt.go @@ -11,7 +11,6 @@ import ( ) type Claims struct { - Sub string `json:"sub"` Type string `json:"type"` jwt.RegisteredClaims } @@ -19,12 +18,12 @@ type Claims struct { func GenerateAccessToken(userID int, secret string) (string, error) { now := time.Now() claims := Claims{ - Sub: strconv.Itoa(userID), Type: "access", RegisteredClaims: jwt.RegisteredClaims{ + Subject: strconv.Itoa(userID), ID: uuid.New().String(), IssuedAt: jwt.NewNumericDate(now), - ExpiresAt: jwt.NewNumericDate(now.Add(15 * time.Minute)), + ExpiresAt: jwt.NewNumericDate(now.Add(2 * time.Minute)), }, } @@ -35,12 +34,12 @@ func GenerateAccessToken(userID int, secret string) (string, error) { func GenerateRefreshToken(userID int, secret string) (string, error) { now := time.Now() claims := Claims{ - Sub: strconv.Itoa(userID), Type: "refresh", RegisteredClaims: jwt.RegisteredClaims{ + Subject: strconv.Itoa(userID), ID: uuid.New().String(), IssuedAt: jwt.NewNumericDate(now), - ExpiresAt: jwt.NewNumericDate(now.Add(30 * 24 * time.Hour)), + ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)), }, } @@ -73,7 +72,7 @@ func GetUserIDFromToken(tokenString, secret string) (int, error) { return 0, err } - userID, err := strconv.Atoi(claims.Sub) + userID, err := strconv.Atoi(claims.Subject) if err != nil { return 0, fmt.Errorf("invalid user ID in token: %w", err) } diff --git a/pkg/pb/invite/invite.pb.go b/pkg/pb/invite/invite.pb.go index 6328374..9bf6326 100644 --- a/pkg/pb/invite/invite.pb.go +++ b/pkg/pb/invite/invite.pb.go @@ -144,7 +144,7 @@ func (x *GenerateResponse) GetExpiresAt() *timestamppb.Timestamp { type GetInfoRequest struct { state protoimpl.MessageState `protogen:"open.v1"` - Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` + UserId int64 `protobuf:"varint,1,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -179,22 +179,19 @@ func (*GetInfoRequest) Descriptor() ([]byte, []int) { return file_invite_invite_proto_rawDescGZIP(), []int{2} } -func (x *GetInfoRequest) GetCode() string { +func (x *GetInfoRequest) GetUserId() int64 { if x != nil { - return x.Code + return x.UserId } - return "" + return 0 } type GetInfoResponse struct { state protoimpl.MessageState `protogen:"open.v1"` Code string `protobuf:"bytes,1,opt,name=code,proto3" json:"code,omitempty"` - UserId int64 `protobuf:"varint,2,opt,name=user_id,json=userId,proto3" json:"user_id,omitempty"` - CanBeUsedCount int32 `protobuf:"varint,3,opt,name=can_be_used_count,json=canBeUsedCount,proto3" json:"can_be_used_count,omitempty"` - UsedCount int32 `protobuf:"varint,4,opt,name=used_count,json=usedCount,proto3" json:"used_count,omitempty"` - ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,5,opt,name=expires_at,json=expiresAt,proto3" json:"expires_at,omitempty"` - IsActive bool `protobuf:"varint,6,opt,name=is_active,json=isActive,proto3" json:"is_active,omitempty"` - CreatedAt *timestamppb.Timestamp `protobuf:"bytes,7,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` + CanBeUsedCount int32 `protobuf:"varint,2,opt,name=can_be_used_count,json=canBeUsedCount,proto3" json:"can_be_used_count,omitempty"` + ExpiresAt *timestamppb.Timestamp `protobuf:"bytes,3,opt,name=expires_at,json=expiresAt,proto3" json:"expires_at,omitempty"` + CreatedAt *timestamppb.Timestamp `protobuf:"bytes,4,opt,name=created_at,json=createdAt,proto3" json:"created_at,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } @@ -236,13 +233,6 @@ func (x *GetInfoResponse) GetCode() string { return "" } -func (x *GetInfoResponse) GetUserId() int64 { - if x != nil { - return x.UserId - } - return 0 -} - func (x *GetInfoResponse) GetCanBeUsedCount() int32 { if x != nil { return x.CanBeUsedCount @@ -250,13 +240,6 @@ func (x *GetInfoResponse) GetCanBeUsedCount() int32 { return 0 } -func (x *GetInfoResponse) GetUsedCount() int32 { - if x != nil { - return x.UsedCount - } - return 0 -} - func (x *GetInfoResponse) GetExpiresAt() *timestamppb.Timestamp { if x != nil { return x.ExpiresAt @@ -264,13 +247,6 @@ func (x *GetInfoResponse) GetExpiresAt() *timestamppb.Timestamp { return nil } -func (x *GetInfoResponse) GetIsActive() bool { - if x != nil { - return x.IsActive - } - return false -} - func (x *GetInfoResponse) GetCreatedAt() *timestamppb.Timestamp { if x != nil { return x.CreatedAt @@ -291,20 +267,16 @@ const file_invite_invite_proto_rawDesc = "" + "\x04code\x18\x01 \x01(\tR\x04code\x12\x19\n" + "\bmax_uses\x18\x02 \x01(\x05R\amaxUses\x129\n" + "\n" + - "expires_at\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\texpiresAt\"$\n" + - "\x0eGetInfoRequest\x12\x12\n" + - "\x04code\x18\x01 \x01(\tR\x04code\"\x9b\x02\n" + + "expires_at\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\texpiresAt\")\n" + + "\x0eGetInfoRequest\x12\x17\n" + + "\auser_id\x18\x01 \x01(\x03R\x06userId\"\xc6\x01\n" + "\x0fGetInfoResponse\x12\x12\n" + - "\x04code\x18\x01 \x01(\tR\x04code\x12\x17\n" + - "\auser_id\x18\x02 \x01(\x03R\x06userId\x12)\n" + - "\x11can_be_used_count\x18\x03 \x01(\x05R\x0ecanBeUsedCount\x12\x1d\n" + + "\x04code\x18\x01 \x01(\tR\x04code\x12)\n" + + "\x11can_be_used_count\x18\x02 \x01(\x05R\x0ecanBeUsedCount\x129\n" + "\n" + - "used_count\x18\x04 \x01(\x05R\tusedCount\x129\n" + + "expires_at\x18\x03 \x01(\v2\x1a.google.protobuf.TimestampR\texpiresAt\x129\n" + "\n" + - "expires_at\x18\x05 \x01(\v2\x1a.google.protobuf.TimestampR\texpiresAt\x12\x1b\n" + - "\tis_active\x18\x06 \x01(\bR\bisActive\x129\n" + - "\n" + - "created_at\x18\a \x01(\v2\x1a.google.protobuf.TimestampR\tcreatedAt2\x8a\x01\n" + + "created_at\x18\x04 \x01(\v2\x1a.google.protobuf.TimestampR\tcreatedAt2\x8a\x01\n" + "\rInviteService\x12=\n" + "\bGenerate\x12\x17.invite.GenerateRequest\x1a\x18.invite.GenerateResponse\x12:\n" + "\aGetInfo\x12\x16.invite.GetInfoRequest\x1a\x17.invite.GetInfoResponseB>Z request.MailingItem - 8, // 1: request.GetMailingListByIDResponse.item:type_name -> request.MailingItem - 9, // 2: request.MailingItem.created_at:type_name -> google.protobuf.Timestamp - 0, // 3: request.RequestService.CreateTZ:input_type -> request.CreateTZRequest - 2, // 4: request.RequestService.ApproveTZ:input_type -> request.ApproveTZRequest - 4, // 5: request.RequestService.GetMailingList:input_type -> request.GetMailingListRequest - 6, // 6: request.RequestService.GetMailingListByID:input_type -> request.GetMailingListByIDRequest - 1, // 7: request.RequestService.CreateTZ:output_type -> request.CreateTZResponse - 3, // 8: request.RequestService.ApproveTZ:output_type -> request.ApproveTZResponse - 5, // 9: request.RequestService.GetMailingList:output_type -> request.GetMailingListResponse - 7, // 10: request.RequestService.GetMailingListByID:output_type -> request.GetMailingListByIDResponse - 7, // [7:11] is the sub-list for method output_type - 3, // [3:7] is the sub-list for method input_type - 3, // [3:3] is the sub-list for extension type_name - 3, // [3:3] is the sub-list for extension extendee - 0, // [0:3] is the sub-list for field type_name + 4, // 0: request.ApproveTZResponse.suppliers:type_name -> request.Supplier + 9, // 1: request.GetMailingListResponse.items:type_name -> request.MailingItem + 10, // 2: request.GetMailingListByIDResponse.detail:type_name -> request.MailingDetail + 11, // 3: request.MailingItem.created_at:type_name -> google.protobuf.Timestamp + 4, // 4: request.MailingDetail.suppliers:type_name -> request.Supplier + 0, // 5: request.RequestService.CreateTZ:input_type -> request.CreateTZRequest + 2, // 6: request.RequestService.ApproveTZ:input_type -> request.ApproveTZRequest + 5, // 7: request.RequestService.GetMailingList:input_type -> request.GetMailingListRequest + 7, // 8: request.RequestService.GetMailingListByID:input_type -> request.GetMailingListByIDRequest + 1, // 9: request.RequestService.CreateTZ:output_type -> request.CreateTZResponse + 3, // 10: request.RequestService.ApproveTZ:output_type -> request.ApproveTZResponse + 6, // 11: request.RequestService.GetMailingList:output_type -> request.GetMailingListResponse + 8, // 12: request.RequestService.GetMailingListByID:output_type -> request.GetMailingListByIDResponse + 9, // [9:13] is the sub-list for method output_type + 5, // [5:9] is the sub-list for method input_type + 5, // [5:5] is the sub-list for extension type_name + 5, // [5:5] is the sub-list for extension extendee + 0, // [0:5] is the sub-list for field type_name } func init() { file_request_request_proto_init() } @@ -626,7 +796,7 @@ func file_request_request_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_request_request_proto_rawDesc), len(file_request_request_proto_rawDesc)), NumEnums: 0, - NumMessages: 9, + NumMessages: 11, NumExtensions: 0, NumServices: 1, }, diff --git a/pkg/pb/user/user.pb.go b/pkg/pb/user/user.pb.go index 863b9ba..652f1b4 100644 --- a/pkg/pb/user/user.pb.go +++ b/pkg/pb/user/user.pb.go @@ -274,13 +274,12 @@ func (x *GetStatisticsRequest) GetUserId() int64 { } type GetStatisticsResponse struct { - state protoimpl.MessageState `protogen:"open.v1"` - TotalRequests int32 `protobuf:"varint,1,opt,name=total_requests,json=totalRequests,proto3" json:"total_requests,omitempty"` - SuccessfulRequests int32 `protobuf:"varint,2,opt,name=successful_requests,json=successfulRequests,proto3" json:"successful_requests,omitempty"` - FailedRequests int32 `protobuf:"varint,3,opt,name=failed_requests,json=failedRequests,proto3" json:"failed_requests,omitempty"` - TotalSpent float64 `protobuf:"fixed64,4,opt,name=total_spent,json=totalSpent,proto3" json:"total_spent,omitempty"` - unknownFields protoimpl.UnknownFields - sizeCache protoimpl.SizeCache + state protoimpl.MessageState `protogen:"open.v1"` + SuppliersCount string `protobuf:"bytes,1,opt,name=suppliers_count,json=suppliersCount,proto3" json:"suppliers_count,omitempty"` + RequestsCount string `protobuf:"bytes,2,opt,name=requests_count,json=requestsCount,proto3" json:"requests_count,omitempty"` + CreatedTz string `protobuf:"bytes,3,opt,name=created_tz,json=createdTz,proto3" json:"created_tz,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache } func (x *GetStatisticsResponse) Reset() { @@ -313,32 +312,25 @@ func (*GetStatisticsResponse) Descriptor() ([]byte, []int) { return file_user_user_proto_rawDescGZIP(), []int{5} } -func (x *GetStatisticsResponse) GetTotalRequests() int32 { +func (x *GetStatisticsResponse) GetSuppliersCount() string { if x != nil { - return x.TotalRequests + return x.SuppliersCount } - return 0 + return "" } -func (x *GetStatisticsResponse) GetSuccessfulRequests() int32 { +func (x *GetStatisticsResponse) GetRequestsCount() string { if x != nil { - return x.SuccessfulRequests + return x.RequestsCount } - return 0 + return "" } -func (x *GetStatisticsResponse) GetFailedRequests() int32 { +func (x *GetStatisticsResponse) GetCreatedTz() string { if x != nil { - return x.FailedRequests + return x.CreatedTz } - return 0 -} - -func (x *GetStatisticsResponse) GetTotalSpent() float64 { - if x != nil { - return x.TotalSpent - } - return 0 + return "" } type GetBalanceStatisticsRequest struct { @@ -385,18 +377,77 @@ func (x *GetBalanceStatisticsRequest) GetUserId() int64 { return 0 } -type GetBalanceStatisticsResponse struct { +type WriteOffHistoryItem struct { state protoimpl.MessageState `protogen:"open.v1"` - Balance float64 `protobuf:"fixed64,1,opt,name=balance,proto3" json:"balance,omitempty"` - TotalRequests int32 `protobuf:"varint,2,opt,name=total_requests,json=totalRequests,proto3" json:"total_requests,omitempty"` - TotalSpent float64 `protobuf:"fixed64,3,opt,name=total_spent,json=totalSpent,proto3" json:"total_spent,omitempty"` + OperationId string `protobuf:"bytes,1,opt,name=operation_id,json=operationId,proto3" json:"operation_id,omitempty"` + Data string `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + Amount float64 `protobuf:"fixed64,3,opt,name=amount,proto3" json:"amount,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } +func (x *WriteOffHistoryItem) Reset() { + *x = WriteOffHistoryItem{} + mi := &file_user_user_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WriteOffHistoryItem) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WriteOffHistoryItem) ProtoMessage() {} + +func (x *WriteOffHistoryItem) ProtoReflect() protoreflect.Message { + mi := &file_user_user_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WriteOffHistoryItem.ProtoReflect.Descriptor instead. +func (*WriteOffHistoryItem) Descriptor() ([]byte, []int) { + return file_user_user_proto_rawDescGZIP(), []int{7} +} + +func (x *WriteOffHistoryItem) GetOperationId() string { + if x != nil { + return x.OperationId + } + return "" +} + +func (x *WriteOffHistoryItem) GetData() string { + if x != nil { + return x.Data + } + return "" +} + +func (x *WriteOffHistoryItem) GetAmount() float64 { + if x != nil { + return x.Amount + } + return 0 +} + +type GetBalanceStatisticsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + AverageCost float64 `protobuf:"fixed64,1,opt,name=average_cost,json=averageCost,proto3" json:"average_cost,omitempty"` + WriteOffHistory []*WriteOffHistoryItem `protobuf:"bytes,2,rep,name=write_off_history,json=writeOffHistory,proto3" json:"write_off_history,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + func (x *GetBalanceStatisticsResponse) Reset() { *x = GetBalanceStatisticsResponse{} - mi := &file_user_user_proto_msgTypes[7] + mi := &file_user_user_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -408,7 +459,7 @@ func (x *GetBalanceStatisticsResponse) String() string { func (*GetBalanceStatisticsResponse) ProtoMessage() {} func (x *GetBalanceStatisticsResponse) ProtoReflect() protoreflect.Message { - mi := &file_user_user_proto_msgTypes[7] + mi := &file_user_user_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -421,28 +472,21 @@ func (x *GetBalanceStatisticsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetBalanceStatisticsResponse.ProtoReflect.Descriptor instead. func (*GetBalanceStatisticsResponse) Descriptor() ([]byte, []int) { - return file_user_user_proto_rawDescGZIP(), []int{7} + return file_user_user_proto_rawDescGZIP(), []int{8} } -func (x *GetBalanceStatisticsResponse) GetBalance() float64 { +func (x *GetBalanceStatisticsResponse) GetAverageCost() float64 { if x != nil { - return x.Balance + return x.AverageCost } return 0 } -func (x *GetBalanceStatisticsResponse) GetTotalRequests() int32 { +func (x *GetBalanceStatisticsResponse) GetWriteOffHistory() []*WriteOffHistoryItem { if x != nil { - return x.TotalRequests + return x.WriteOffHistory } - return 0 -} - -func (x *GetBalanceStatisticsResponse) GetTotalSpent() float64 { - if x != nil { - return x.TotalSpent - } - return 0 + return nil } var File_user_user_proto protoreflect.FileDescriptor @@ -463,20 +507,21 @@ const file_user_user_proto_rawDesc = "" + "\x12GetBalanceResponse\x12\x18\n" + "\abalance\x18\x01 \x01(\x01R\abalance\"/\n" + "\x14GetStatisticsRequest\x12\x17\n" + - "\auser_id\x18\x01 \x01(\x03R\x06userId\"\xb9\x01\n" + - "\x15GetStatisticsResponse\x12%\n" + - "\x0etotal_requests\x18\x01 \x01(\x05R\rtotalRequests\x12/\n" + - "\x13successful_requests\x18\x02 \x01(\x05R\x12successfulRequests\x12'\n" + - "\x0ffailed_requests\x18\x03 \x01(\x05R\x0efailedRequests\x12\x1f\n" + - "\vtotal_spent\x18\x04 \x01(\x01R\n" + - "totalSpent\"6\n" + + "\auser_id\x18\x01 \x01(\x03R\x06userId\"\x86\x01\n" + + "\x15GetStatisticsResponse\x12'\n" + + "\x0fsuppliers_count\x18\x01 \x01(\tR\x0esuppliersCount\x12%\n" + + "\x0erequests_count\x18\x02 \x01(\tR\rrequestsCount\x12\x1d\n" + + "\n" + + "created_tz\x18\x03 \x01(\tR\tcreatedTz\"6\n" + "\x1bGetBalanceStatisticsRequest\x12\x17\n" + - "\auser_id\x18\x01 \x01(\x03R\x06userId\"\x80\x01\n" + - "\x1cGetBalanceStatisticsResponse\x12\x18\n" + - "\abalance\x18\x01 \x01(\x01R\abalance\x12%\n" + - "\x0etotal_requests\x18\x02 \x01(\x05R\rtotalRequests\x12\x1f\n" + - "\vtotal_spent\x18\x03 \x01(\x01R\n" + - "totalSpent2\xaf\x02\n" + + "\auser_id\x18\x01 \x01(\x03R\x06userId\"d\n" + + "\x13WriteOffHistoryItem\x12!\n" + + "\foperation_id\x18\x01 \x01(\tR\voperationId\x12\x12\n" + + "\x04data\x18\x02 \x01(\tR\x04data\x12\x16\n" + + "\x06amount\x18\x03 \x01(\x01R\x06amount\"\x88\x01\n" + + "\x1cGetBalanceStatisticsResponse\x12!\n" + + "\faverage_cost\x18\x01 \x01(\x01R\vaverageCost\x12E\n" + + "\x11write_off_history\x18\x02 \x03(\v2\x19.user.WriteOffHistoryItemR\x0fwriteOffHistory2\xaf\x02\n" + "\vUserService\x126\n" + "\aGetInfo\x12\x14.user.GetInfoRequest\x1a\x15.user.GetInfoResponse\x12?\n" + "\n" + @@ -496,7 +541,7 @@ func file_user_user_proto_rawDescGZIP() []byte { return file_user_user_proto_rawDescData } -var file_user_user_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_user_user_proto_msgTypes = make([]protoimpl.MessageInfo, 9) var file_user_user_proto_goTypes = []any{ (*GetInfoRequest)(nil), // 0: user.GetInfoRequest (*GetInfoResponse)(nil), // 1: user.GetInfoResponse @@ -505,22 +550,24 @@ var file_user_user_proto_goTypes = []any{ (*GetStatisticsRequest)(nil), // 4: user.GetStatisticsRequest (*GetStatisticsResponse)(nil), // 5: user.GetStatisticsResponse (*GetBalanceStatisticsRequest)(nil), // 6: user.GetBalanceStatisticsRequest - (*GetBalanceStatisticsResponse)(nil), // 7: user.GetBalanceStatisticsResponse + (*WriteOffHistoryItem)(nil), // 7: user.WriteOffHistoryItem + (*GetBalanceStatisticsResponse)(nil), // 8: user.GetBalanceStatisticsResponse } var file_user_user_proto_depIdxs = []int32{ - 0, // 0: user.UserService.GetInfo:input_type -> user.GetInfoRequest - 2, // 1: user.UserService.GetBalance:input_type -> user.GetBalanceRequest - 4, // 2: user.UserService.GetStatistics:input_type -> user.GetStatisticsRequest - 6, // 3: user.UserService.GetBalanceStatistics:input_type -> user.GetBalanceStatisticsRequest - 1, // 4: user.UserService.GetInfo:output_type -> user.GetInfoResponse - 3, // 5: user.UserService.GetBalance:output_type -> user.GetBalanceResponse - 5, // 6: user.UserService.GetStatistics:output_type -> user.GetStatisticsResponse - 7, // 7: user.UserService.GetBalanceStatistics:output_type -> user.GetBalanceStatisticsResponse - 4, // [4:8] is the sub-list for method output_type - 0, // [0:4] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name + 7, // 0: user.GetBalanceStatisticsResponse.write_off_history:type_name -> user.WriteOffHistoryItem + 0, // 1: user.UserService.GetInfo:input_type -> user.GetInfoRequest + 2, // 2: user.UserService.GetBalance:input_type -> user.GetBalanceRequest + 4, // 3: user.UserService.GetStatistics:input_type -> user.GetStatisticsRequest + 6, // 4: user.UserService.GetBalanceStatistics:input_type -> user.GetBalanceStatisticsRequest + 1, // 5: user.UserService.GetInfo:output_type -> user.GetInfoResponse + 3, // 6: user.UserService.GetBalance:output_type -> user.GetBalanceResponse + 5, // 7: user.UserService.GetStatistics:output_type -> user.GetStatisticsResponse + 8, // 8: user.UserService.GetBalanceStatistics:output_type -> user.GetBalanceStatisticsResponse + 5, // [5:9] is the sub-list for method output_type + 1, // [1:5] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_user_user_proto_init() } @@ -534,7 +581,7 @@ func file_user_user_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_user_user_proto_rawDesc), len(file_user_user_proto_rawDesc)), NumEnums: 0, - NumMessages: 8, + NumMessages: 9, NumExtensions: 0, NumServices: 1, }, diff --git a/tests/auth_handler_test.go b/tests/auth_handler_test.go index def4a1b..298ea97 100644 --- a/tests/auth_handler_test.go +++ b/tests/auth_handler_test.go @@ -208,7 +208,7 @@ func (s *IntegrationSuite) TestAuthHandler_RegisterInvalidInviteCode() { st, ok := status.FromError(err) s.True(ok) - s.Equal(codes.FailedPrecondition, st.Code()) + s.Equal(codes.NotFound, st.Code()) } func (s *IntegrationSuite) TestAuthHandler_RegisterExpiredInviteCode() { @@ -232,7 +232,7 @@ func (s *IntegrationSuite) TestAuthHandler_RegisterExpiredInviteCode() { st, ok := status.FromError(err) s.True(ok) - s.Equal(codes.FailedPrecondition, st.Code()) + s.Equal(codes.NotFound, st.Code()) } func (s *IntegrationSuite) TestAuthHandler_RegisterExhaustedInviteCode() { @@ -270,7 +270,7 @@ func (s *IntegrationSuite) TestAuthHandler_RegisterExhaustedInviteCode() { st, ok := status.FromError(err) s.True(ok) - s.Equal(codes.FailedPrecondition, st.Code()) + s.Equal(codes.NotFound, st.Code()) } func (s *IntegrationSuite) TestAuthHandler_RegisterDuplicateEmail() { diff --git a/tests/concurrent_ownership_test.go b/tests/concurrent_ownership_test.go new file mode 100644 index 0000000..5c21d1d --- /dev/null +++ b/tests/concurrent_ownership_test.go @@ -0,0 +1,243 @@ +package tests + +import ( + "sync" + "sync/atomic" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func (s *IntegrationSuite) TestConcurrentOwnership_User2TriesApproveTZ_WhileUser1Creates() { + email1, password1, _ := s.createUniqueTestUser("owner1", 1000.0) + email2, password2, _ := s.createUniqueTestUser("attacker1", 1000.0) + + login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email1, + Password: password1, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: login1.AccessToken, + }) + s.Require().NoError(err) + user1ID := validate1.UserId + + login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email2, + Password: password2, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validate2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: login2.AccessToken, + }) + s.Require().NoError(err) + user2ID := validate2.UserId + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: user1ID, + RequestTxt: "Request от User1 для теста ownership", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + var wg sync.WaitGroup + var user1Success, user2Denied int32 + + startBarrier := make(chan struct{}) + + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "User1 approves", + UserId: user1ID, + }) + if err == nil { + atomic.AddInt32(&user1Success, 1) + } + }() + + for i := 0; i < 5; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "User2 tries to approve", + UserId: user2ID, + }) + if err != nil { + st, ok := status.FromError(err) + if ok && st.Code() == codes.PermissionDenied { + atomic.AddInt32(&user2Denied, 1) + } + } + }() + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("User1 success: %d, User2 denied: %d", user1Success, user2Denied) + + s.Equal(int32(5), user2Denied, + "Все попытки User2 должны быть отклонены с PermissionDenied") +} + +func (s *IntegrationSuite) TestConcurrentOwnership_ConcurrentApproveTZ_SameRequest() { + email, password, _ := s.createUniqueTestUser("concurrent_approve", 1000.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + userID := validateResp.UserId + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Request для concurrent ApproveTZ", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + var wg sync.WaitGroup + var successCount int32 + goroutines := 5 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Concurrent approve attempt", + UserId: userID, + }) + if err == nil { + atomic.AddInt32(&successCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Concurrent ApproveTZ success count: %d", successCount) + + suppliersCount := s.getRequestSuppliersCount(requestID) + s.T().Logf("Total suppliers for request: %d", suppliersCount) +} + +func (s *IntegrationSuite) TestConcurrentOwnership_SessionIsolation_AfterLogout() { + email1, password1, _ := s.createUniqueTestUser("session_iso1", 1000.0) + email2, password2, _ := s.createUniqueTestUser("session_iso2", 1000.0) + + login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email1, + Password: password1, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + user1Token := login1.AccessToken + + validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user1Token, + }) + s.Require().NoError(err) + s.True(validate1.Valid) + + login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email2, + Password: password2, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + user2Token := login2.AccessToken + + _, err = s.authClient.Logout(s.ctx, &authpb.LogoutRequest{ + AccessToken: user1Token, + }) + s.Require().NoError(err) + + validate1After, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user1Token, + }) + s.NoError(err) + s.False(validate1After.Valid, + "Токен User1 должен быть невалиден после logout") + + validate2After, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user2Token, + }) + s.NoError(err) + s.True(validate2After.Valid, + "Токен User2 должен оставаться валидным после logout User1") + + var wg sync.WaitGroup + var user1Invalid, user2Valid int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + if idx%2 == 0 { + resp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user1Token, + }) + if err == nil && !resp.Valid { + atomic.AddInt32(&user1Invalid, 1) + } + } else { + resp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: user2Token, + }) + if err == nil && resp.Valid { + atomic.AddInt32(&user2Valid, 1) + } + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Session isolation - User1 invalid: %d, User2 valid: %d", user1Invalid, user2Valid) + + s.Equal(int32(goroutines/2), user1Invalid, + "Все проверки токена User1 должны показать invalid") + s.Equal(int32(goroutines/2), user2Valid, + "Все проверки токена User2 должны показать valid") +} diff --git a/tests/concurrent_registration_test.go b/tests/concurrent_registration_test.go new file mode 100644 index 0000000..1554d79 --- /dev/null +++ b/tests/concurrent_registration_test.go @@ -0,0 +1,178 @@ +package tests + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" +) + +func (s *IntegrationSuite) TestConcurrent_Registration_WithSingleInviteCode() { + maxUses := 3 + inviteCode := s.createActiveInviteCode(maxUses) + + var wg sync.WaitGroup + var successCount int32 + var errorCount int32 + goroutines := 20 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + email := fmt.Sprintf("concurrent_reg_%d_%d@example.com", idx, time.Now().UnixNano()) + + _, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email, + Password: "testpassword123", + Name: fmt.Sprintf("User %d", idx), + Phone: fmt.Sprintf("+1%010d", idx), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } else { + atomic.AddInt32(&errorCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Registration results - Success: %d, Errors: %d", successCount, errorCount) + + s.LessOrEqual(int(successCount), maxUses, + "Количество успешных регистраций (%d) не должно превышать лимит invite-кода (%d)", successCount, maxUses) + + remainingUses := s.getInviteCodeUsageCount(inviteCode) + s.T().Logf("Remaining invite code uses: %d", remainingUses) + + s.Equal(maxUses-int(successCount), remainingUses, + "Оставшееся количество использований должно соответствовать успешным регистрациям") +} + +func (s *IntegrationSuite) TestConcurrent_Registration_InviteCodeDeactivation() { + maxUses := 2 + inviteCode := s.createActiveInviteCode(maxUses) + + s.True(s.isInviteCodeActive(inviteCode), "Invite code должен быть активен изначально") + + var wg sync.WaitGroup + var successCount int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + email := fmt.Sprintf("deactivation_test_%d_%d@example.com", idx, time.Now().UnixNano()) + + _, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email, + Password: "testpassword123", + Name: fmt.Sprintf("User %d", idx), + Phone: fmt.Sprintf("+2%010d", idx), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "integration-test", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Registration success count: %d", successCount) + + s.LessOrEqual(int(successCount), maxUses, + "Не должно быть больше %d успешных регистраций", maxUses) + + remainingUses := s.getInviteCodeUsageCount(inviteCode) + s.GreaterOrEqual(remainingUses, 0, + "Количество использований не должно быть отрицательным") +} + +func (s *IntegrationSuite) TestConcurrent_Registration_MultipleInviteCodes() { + inviteCode1 := s.createActiveInviteCode(2) + inviteCode2 := s.createActiveInviteCode(2) + + var wg sync.WaitGroup + var success1, success2 int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + var code int64 + if idx%2 == 0 { + code = inviteCode1 + } else { + code = inviteCode2 + } + + email := fmt.Sprintf("multi_invite_%d_%d@example.com", idx, time.Now().UnixNano()) + + _, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email, + Password: "testpassword123", + Name: fmt.Sprintf("User %d", idx), + Phone: fmt.Sprintf("+3%010d", idx), + InviteCode: code, + Ip: "127.0.0.1", + UserAgent: "integration-test", + }) + + if err == nil { + if code == inviteCode1 { + atomic.AddInt32(&success1, 1) + } else { + atomic.AddInt32(&success2, 1) + } + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Multi-invite results - Code1: %d, Code2: %d", success1, success2) + + s.LessOrEqual(int(success1), 2, + "Invite code 1 не должен превышать лимит") + s.LessOrEqual(int(success2), 2, + "Invite code 2 не должен превышать лимит") + + remaining1 := s.getInviteCodeUsageCount(inviteCode1) + remaining2 := s.getInviteCodeUsageCount(inviteCode2) + + s.Equal(2-int(success1), remaining1, + "Остаток invite code 1 должен соответствовать успешным регистрациям") + s.Equal(2-int(success2), remaining2, + "Остаток invite code 2 должен соответствовать успешным регистрациям") +} diff --git a/tests/concurrent_request_test.go b/tests/concurrent_request_test.go new file mode 100644 index 0000000..78a5fe5 --- /dev/null +++ b/tests/concurrent_request_test.go @@ -0,0 +1,219 @@ +package tests + +import ( + "sync" + "sync/atomic" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" +) + +func (s *IntegrationSuite) TestConcurrentRequest_CreateTZ_LimitedBalance() { + initialBalance := 50.0 + email, password, userID := s.createUniqueTestUser("limited_balance_tz", initialBalance) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + var wg sync.WaitGroup + var successCount int32 + var errorCount int32 + goroutines := 20 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Параллельный CreateTZ с ограниченным балансом", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } else { + atomic.AddInt32(&errorCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("CreateTZ with limited balance - Success: %d, Errors: %d", successCount, errorCount) + + finalBalance := s.getUserBalance(userID) + s.T().Logf("Final balance: %.4f (initial: %.4f)", finalBalance, initialBalance) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным") + + var requestsWithTZ int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM requests_for_suppliers WHERE user_id = $1 AND generated_tz = true", + userID, + ).Scan(&requestsWithTZ) + s.NoError(err) + + s.T().Logf("Requests with generated TZ: %d", requestsWithTZ) + + s.GreaterOrEqual(requestsWithTZ, 0, + "Количество запросов с TZ должно быть >= 0") + + s.LessOrEqual(requestsWithTZ, int(successCount), + "Количество запросов с TZ не должно превышать успешные операции") +} + +func (s *IntegrationSuite) TestConcurrentRequest_MultipleUsers_CreateTZ() { + user1Email, user1Pass, user1ID := s.createUniqueTestUser("multi_user1", 500.0) + user2Email, user2Pass, user2ID := s.createUniqueTestUser("multi_user2", 500.0) + user3Email, user3Pass, user3ID := s.createUniqueTestUser("multi_user3", 500.0) + + login1, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: user1Email, Password: user1Pass, Ip: "127.0.0.1", UserAgent: "test", + }) + s.Require().NoError(err) + validate1, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login1.AccessToken}) + + login2, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: user2Email, Password: user2Pass, Ip: "127.0.0.1", UserAgent: "test", + }) + s.Require().NoError(err) + validate2, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login2.AccessToken}) + + login3, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: user3Email, Password: user3Pass, Ip: "127.0.0.1", UserAgent: "test", + }) + s.Require().NoError(err) + validate3, _ := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{AccessToken: login3.AccessToken}) + + users := []struct { + userID int64 + id int + }{ + {validate1.UserId, user1ID}, + {validate2.UserId, user2ID}, + {validate3.UserId, user3ID}, + } + + var wg sync.WaitGroup + var totalSuccess int32 + requestsPerUser := 5 + + startBarrier := make(chan struct{}) + + for _, user := range users { + for i := 0; i < requestsPerUser; i++ { + wg.Add(1) + go func(uid int64) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: uid, + RequestTxt: "Multi-user concurrent CreateTZ", + }) + + if err == nil { + atomic.AddInt32(&totalSuccess, 1) + } + }(user.userID) + } + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Multi-user CreateTZ total success: %d", totalSuccess) + + for _, user := range users { + balance := s.getUserBalance(user.id) + s.T().Logf("User %d final balance: %.4f", user.id, balance) + s.GreaterOrEqual(balance, 0.0, + "Баланс пользователя %d не должен быть отрицательным", user.id) + } +} + +func (s *IntegrationSuite) TestConcurrentRequest_BalanceDeduction_Consistency() { + initialBalance := 1000.0 + email, password, userID := s.createUniqueTestUser("balance_consistency", initialBalance) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + var wg sync.WaitGroup + var successCount int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Balance consistency test", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Successful CreateTZ operations: %d", successCount) + + finalBalance := s.getUserBalance(userID) + balanceSpent := initialBalance - finalBalance + s.T().Logf("Balance spent: %.4f", balanceSpent) + + var totalTokenCost float64 + err = s.pool.QueryRow(s.ctx, ` + SELECT COALESCE(SUM(tu.token_cost), 0) + FROM request_token_usage tu + JOIN requests_for_suppliers r ON tu.request_id = r.id + WHERE r.user_id = $1 + `, userID).Scan(&totalTokenCost) + s.NoError(err) + + s.T().Logf("Total token cost from DB: %.4f", totalTokenCost) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным") + + if totalTokenCost > 0 { + tolerance := 0.01 + s.InDelta(totalTokenCost, balanceSpent, tolerance, + "Сумма token_cost должна соответствовать списанному балансу") + } +} diff --git a/tests/edge_cases_test.go b/tests/edge_cases_test.go index 42fe9a7..77bf4ad 100644 --- a/tests/edge_cases_test.go +++ b/tests/edge_cases_test.go @@ -283,3 +283,140 @@ func (s *IntegrationSuite) TestEdgeCase_LoginWithWrongPassword() { s.True(ok) s.Equal(codes.Unauthenticated, st.Code()) } + +func (s *IntegrationSuite) TestEdgeCase_ApproveTZTwice() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + userID := validateResp.UserId + + createReq := &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Тест двойного approve", + } + + createResp, err := s.requestClient.CreateTZ(ctx, createReq) + s.NoError(err) + requestID := createResp.RequestId + + approveReq1 := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Первое утверждение", + UserId: userID, + } + + approveResp1, err := s.requestClient.ApproveTZ(ctx, approveReq1) + s.NoError(err) + s.NotEmpty(approveResp1.RequestId) + + approveReq2 := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Второе утверждение", + UserId: userID, + } + + approveResp2, err := s.requestClient.ApproveTZ(ctx, approveReq2) + s.NoError(err) + s.NotEmpty(approveResp2.RequestId) +} + +func (s *IntegrationSuite) TestEdgeCase_CreateTZWithVeryLongText() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + + _, err = s.pool.Exec(ctx, "UPDATE users SET balance = 10000 WHERE id = $1", validateResp.UserId) + s.NoError(err) + + longText := "Нужны поставщики. " + for i := 0; i < 500; i++ { + longText += "Дополнительные требования к качеству и срокам поставки материалов. " + } + + req := &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: longText, + } + + resp, err := s.requestClient.CreateTZ(ctx, req) + s.NoError(err) + s.NotNil(resp) + s.NotEmpty(resp.RequestId) + s.NotEmpty(resp.TzText) +} + +func (s *IntegrationSuite) TestEdgeCase_ApproveTZWithVeryLongFinalTZ() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + userID := validateResp.UserId + + createReq := &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Тест длинного ТЗ", + } + + createResp, err := s.requestClient.CreateTZ(ctx, createReq) + s.NoError(err) + requestID := createResp.RequestId + + longFinalTZ := "ТЕХНИЧЕСКОЕ ЗАДАНИЕ\n\n" + for i := 0; i < 500; i++ { + longFinalTZ += "Пункт требований с детальным описанием спецификации и условий поставки. " + } + + approveReq := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: longFinalTZ, + UserId: userID, + } + + approveResp, err := s.requestClient.ApproveTZ(ctx, approveReq) + s.NoError(err) + s.NotEmpty(approveResp.RequestId) +} diff --git a/tests/full_flow_test.go b/tests/full_flow_test.go index e71fe59..1a76c26 100644 --- a/tests/full_flow_test.go +++ b/tests/full_flow_test.go @@ -76,7 +76,7 @@ func (s *IntegrationSuite) TestFullFlow_CompleteRequestLifecycle() { s.T().Logf("ApproveTZ failed: %v", err) return } - s.True(approveTZResp.Success) + s.NotEmpty(approveTZResp.RequestId) getMailingListReq := &requestpb.GetMailingListRequest{ UserId: userID, @@ -97,8 +97,8 @@ func (s *IntegrationSuite) TestFullFlow_CompleteRequestLifecycle() { s.T().Logf("GetMailingListByID failed: %v", err) return } - s.NotNil(mailingListByIDResp.Item) - s.Equal(requestID, mailingListByIDResp.Item.RequestId) + s.NotNil(mailingListByIDResp.Detail) + s.Equal(requestID, mailingListByIDResp.Detail.RequestId) exportExcelReq := &supplierpb.ExportExcelRequest{ RequestId: requestID, @@ -119,7 +119,7 @@ func (s *IntegrationSuite) TestFullFlow_CompleteRequestLifecycle() { statisticsResp, err := s.userClient.GetStatistics(ctx, getStatisticsReq) s.NoError(err) - s.GreaterOrEqual(statisticsResp.TotalRequests, int32(0)) + s.NotEmpty(statisticsResp.RequestsCount) logoutReq := &authpb.LogoutRequest{ AccessToken: loginResp.AccessToken, @@ -164,16 +164,90 @@ func (s *IntegrationSuite) TestFullFlow_InviteCodeLifecycle() { inviteCode := generateInviteResp.Code getInviteInfoReq := &invitepb.GetInfoRequest{ - Code: inviteCode, + UserId: userID, } inviteInfoResp, err := s.inviteClient.GetInfo(ctx, getInviteInfoReq) s.NoError(err) s.Equal(inviteCode, inviteInfoResp.Code) - s.Equal(userID, inviteInfoResp.UserId) s.Equal(generateInviteResp.MaxUses, inviteInfoResp.CanBeUsedCount) - s.Equal(int32(0), inviteInfoResp.UsedCount) - s.True(inviteInfoResp.IsActive) + + logoutReq := &authpb.LogoutRequest{ + AccessToken: loginResp.AccessToken, + } + + logoutResp, err := s.authClient.Logout(ctx, logoutReq) + s.NoError(err) + s.True(logoutResp.Success) +} + +func (s *IntegrationSuite) TestFullFlow_CreateTZ_ApproveTZ_GetMailingListByID_ExportExcel() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + s.NotEmpty(loginResp.AccessToken) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + s.True(validateResp.Valid) + userID := validateResp.UserId + + createTZReq := &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Нужны поставщики офисной мебели: столы 20 шт, стулья 50 шт", + } + + createTZResp, err := s.requestClient.CreateTZ(ctx, createTZReq) + s.NoError(err) + s.NotEmpty(createTZResp.RequestId) + s.NotEmpty(createTZResp.TzText) + requestID := createTZResp.RequestId + + approveTZReq := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: createTZResp.TzText, + UserId: userID, + } + + approveTZResp, err := s.requestClient.ApproveTZ(ctx, approveTZReq) + s.NoError(err) + s.NotEmpty(approveTZResp.RequestId) + + getMailingListByIDReq := &requestpb.GetMailingListByIDRequest{ + RequestId: requestID, + UserId: userID, + } + + mailingListByIDResp, err := s.requestClient.GetMailingListByID(ctx, getMailingListByIDReq) + s.NoError(err) + s.NotNil(mailingListByIDResp.Detail) + s.Equal(requestID, mailingListByIDResp.Detail.RequestId) + s.Greater(len(mailingListByIDResp.Detail.Suppliers), 0) + + exportExcelReq := &supplierpb.ExportExcelRequest{ + RequestId: requestID, + UserId: userID, + } + + exportExcelResp, err := s.supplierClient.ExportExcel(ctx, exportExcelReq) + s.NoError(err) + s.NotNil(exportExcelResp) + s.NotEmpty(exportExcelResp.FileName) + s.NotEmpty(exportExcelResp.FileData) + s.Equal("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", exportExcelResp.MimeType) + s.Greater(len(exportExcelResp.FileData), 0) logoutReq := &authpb.LogoutRequest{ AccessToken: loginResp.AccessToken, diff --git a/tests/idempotency_test.go b/tests/idempotency_test.go new file mode 100644 index 0000000..a417ebc --- /dev/null +++ b/tests/idempotency_test.go @@ -0,0 +1,155 @@ +package tests + +import ( + "fmt" + "time" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" +) + +func (s *IntegrationSuite) TestIdempotency_DoubleCreateTZ_CreatesTwoRequests() { + email, password, userID := s.createUniqueTestUser("idempotency_tz", 1000.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + requestText := "Одинаковый текст запроса для теста идемпотентности" + + resp1, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: requestText, + }) + s.Require().NoError(err) + requestID1 := resp1.RequestId + + resp2, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: requestText, + }) + s.Require().NoError(err) + requestID2 := resp2.RequestId + + s.T().Logf("Request 1 ID: %s", requestID1) + s.T().Logf("Request 2 ID: %s", requestID2) + + s.NotEqual(requestID1, requestID2, + "Два вызова CreateTZ должны создать два разных request") + + var requestCount int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM requests_for_suppliers WHERE user_id = $1 AND request_txt = $2", + userID, requestText, + ).Scan(&requestCount) + s.NoError(err) + + s.Equal(2, requestCount, + "Должно быть создано 2 запроса с одинаковым текстом") +} + +func (s *IntegrationSuite) TestIdempotency_DoubleRegister_SameInviteCode() { + inviteCode := s.createActiveInviteCode(5) + + email1 := fmt.Sprintf("double_reg1_%d@example.com", time.Now().UnixNano()) + + resp1, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email1, + Password: "testpassword", + Name: "User 1", + Phone: fmt.Sprintf("+1%010d", time.Now().UnixNano()%10000000000), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + s.NotEmpty(resp1.AccessToken) + + email2 := fmt.Sprintf("double_reg2_%d@example.com", time.Now().UnixNano()) + + resp2, err := s.authClient.Register(s.ctx, &authpb.RegisterRequest{ + Email: email2, + Password: "testpassword", + Name: "User 2", + Phone: fmt.Sprintf("+2%010d", time.Now().UnixNano()%10000000000), + InviteCode: inviteCode, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + s.NotEmpty(resp2.AccessToken) + + remainingUses := s.getInviteCodeUsageCount(inviteCode) + s.T().Logf("Remaining invite uses: %d", remainingUses) + + s.Equal(3, remainingUses, + "После двух регистраций должно остаться 3 использования (5-2)") + + validate1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: resp1.AccessToken, + }) + s.NoError(err) + + validate2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: resp2.AccessToken, + }) + s.NoError(err) + + s.NotEqual(validate1.UserId, validate2.UserId, + "Должны быть созданы два разных пользователя") +} + +func (s *IntegrationSuite) TestIdempotency_DoubleLogout_SameToken() { + email, password, _ := s.createUniqueTestUser("double_logout", 100.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + accessToken := loginResp.AccessToken + + validateBefore, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.True(validateBefore.Valid) + + logout1, err := s.authClient.Logout(s.ctx, &authpb.LogoutRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.True(logout1.Success) + + validateAfter1, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.False(validateAfter1.Valid, + "Токен должен быть невалиден после первого logout") + + logout2, err := s.authClient.Logout(s.ctx, &authpb.LogoutRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.True(logout2.Success, + "Повторный logout должен быть успешным (идемпотентность)") + + validateAfter2, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: accessToken, + }) + s.NoError(err) + s.False(validateAfter2.Valid, + "Токен должен оставаться невалидным после повторного logout") +} diff --git a/tests/integration_suite_test.go b/tests/integration_suite_test.go index ae3d204..b1f745f 100644 --- a/tests/integration_suite_test.go +++ b/tests/integration_suite_test.go @@ -78,7 +78,8 @@ func (s *IntegrationSuite) SetupSuite() { s.T().Logf("PostgreSQL connection string: %s", connStr) s.T().Log("Running migrations...") - err = database.RunMigrationsFromPath(connStr, "../migrations") + logger, _ := zap.NewDevelopment() + err = database.RunMigrationsFromPath(connStr, "../migrations", logger) s.Require().NoError(err) s.T().Log("Creating connection pool...") @@ -94,7 +95,6 @@ func (s *IntegrationSuite) SetupSuite() { s.Require().NoError(err) s.T().Log("Creating gRPC server...") - logger, _ := zap.NewDevelopment() authHandler, userHandler, inviteHandler, requestHandler, supplierHandler := grpchandlers.NewHandlers( pool, @@ -179,8 +179,8 @@ func (s *IntegrationSuite) createTestUser(email, password string) { func (s *IntegrationSuite) createActiveInviteCode(canBeUsedCount int) int64 { var inviteCode int64 query := ` - INSERT INTO invite_codes (user_id, code, can_be_used_count, used_count, is_active, expires_at) - VALUES (1, FLOOR(100000 + RANDOM() * 900000)::bigint, $1, 0, true, NOW() + INTERVAL '30 days') + INSERT INTO invite_codes (user_id, code, can_be_used_count, is_active, expires_at) + VALUES (1, FLOOR(100000 + RANDOM() * 900000)::bigint, $1, true, NOW() + INTERVAL '30 days') RETURNING code ` err := s.pool.QueryRow(s.ctx, query, canBeUsedCount).Scan(&inviteCode) @@ -191,8 +191,8 @@ func (s *IntegrationSuite) createActiveInviteCode(canBeUsedCount int) int64 { func (s *IntegrationSuite) createExpiredInviteCode() int64 { var inviteCode int64 query := ` - INSERT INTO invite_codes (user_id, code, can_be_used_count, used_count, is_active, expires_at) - VALUES (1, FLOOR(100000 + RANDOM() * 900000)::bigint, 5, 0, true, NOW() - INTERVAL '1 day') + INSERT INTO invite_codes (user_id, code, can_be_used_count, is_active, expires_at) + VALUES (1, FLOOR(100000 + RANDOM() * 900000)::bigint, 5, true, NOW() - INTERVAL '1 day') RETURNING code ` err := s.pool.QueryRow(s.ctx, query).Scan(&inviteCode) @@ -233,3 +233,146 @@ func (s *IntegrationSuite) TearDownTest() { _, _ = s.pool.Exec(s.ctx, "DELETE FROM suppliers") _, _ = s.pool.Exec(s.ctx, "DELETE FROM requests_for_suppliers") } + +func (s *IntegrationSuite) createSecondTestUser() (email string, password string, userID int64) { + email = "second_user@example.com" + password = "secondpassword" + + cryptoHelper := crypto.NewCrypto(testCryptoSecret) + + encryptedEmail, err := cryptoHelper.Encrypt(email) + s.Require().NoError(err) + + encryptedPhone, err := cryptoHelper.Encrypt("+9876543210") + s.Require().NoError(err) + + encryptedUserName, err := cryptoHelper.Encrypt("Second User") + s.Require().NoError(err) + + emailHash := cryptoHelper.EmailHash(email) + passwordHash := crypto.PasswordHash(password) + + query := ` + INSERT INTO users (email, email_hash, password_hash, phone, user_name, company_name, balance, payment_status, invites_issued, invites_limit) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + ON CONFLICT (email_hash) DO UPDATE SET balance = $7 + RETURNING id + ` + + err = s.pool.QueryRow(s.ctx, query, + encryptedEmail, + emailHash, + passwordHash, + encryptedPhone, + encryptedUserName, + "Second Company", + 1000.0, + "active", + 0, + 10, + ).Scan(&userID) + s.Require().NoError(err) + + return email, password, userID +} + +func (s *IntegrationSuite) getInviteCodeUsageCount(code int64) int { + var count int + err := s.pool.QueryRow(s.ctx, + "SELECT can_be_used_count FROM invite_codes WHERE code = $1", + code, + ).Scan(&count) + if err != nil { + return -1 + } + return count +} + +func (s *IntegrationSuite) getRequestSuppliersCount(requestID string) int { + var count int + err := s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM suppliers WHERE request_id = $1::uuid", + requestID, + ).Scan(&count) + if err != nil { + return -1 + } + return count +} + +func (s *IntegrationSuite) getUserBalance(userID int) float64 { + var balance float64 + err := s.pool.QueryRow(s.ctx, + "SELECT balance FROM users WHERE id = $1", + userID, + ).Scan(&balance) + if err != nil { + return -1 + } + return balance +} + +func (s *IntegrationSuite) getTokenUsageCount(requestID string) int { + var count int + err := s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM request_token_usage WHERE request_id = $1::uuid", + requestID, + ).Scan(&count) + if err != nil { + return -1 + } + return count +} + +func (s *IntegrationSuite) createUniqueTestUser(suffix string, balance float64) (email string, password string, userID int) { + email = fmt.Sprintf("user_%s_%d@example.com", suffix, time.Now().UnixNano()) + password = "testpassword" + + cryptoHelper := crypto.NewCrypto(testCryptoSecret) + + encryptedEmail, err := cryptoHelper.Encrypt(email) + s.Require().NoError(err) + + encryptedPhone, err := cryptoHelper.Encrypt(fmt.Sprintf("+1%d", time.Now().UnixNano()%10000000000)) + s.Require().NoError(err) + + encryptedUserName, err := cryptoHelper.Encrypt(fmt.Sprintf("User %s", suffix)) + s.Require().NoError(err) + + emailHash := cryptoHelper.EmailHash(email) + passwordHash := crypto.PasswordHash(password) + + query := ` + INSERT INTO users (email, email_hash, password_hash, phone, user_name, company_name, balance, payment_status, invites_issued, invites_limit) + VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10) + RETURNING id + ` + + err = s.pool.QueryRow(s.ctx, query, + encryptedEmail, + emailHash, + passwordHash, + encryptedPhone, + encryptedUserName, + "Test Company", + balance, + "active", + 0, + 10, + ).Scan(&userID) + s.Require().NoError(err) + + return email, password, userID +} + +func (s *IntegrationSuite) isInviteCodeActive(code int64) bool { + var isActive bool + err := s.pool.QueryRow(s.ctx, + "SELECT is_active FROM invite_codes WHERE code = $1", + code, + ).Scan(&isActive) + if err != nil { + return false + } + return isActive +} diff --git a/tests/invite_handler_test.go b/tests/invite_handler_test.go index 235228d..835065d 100644 --- a/tests/invite_handler_test.go +++ b/tests/invite_handler_test.go @@ -26,9 +26,9 @@ func (s *IntegrationSuite) TestInviteHandler_GenerateWithNonExistentUser() { s.Contains([]codes.Code{codes.NotFound, codes.Internal, codes.Unknown}, st.Code()) } -func (s *IntegrationSuite) TestInviteHandler_GetInfoWithInvalidCode() { +func (s *IntegrationSuite) TestInviteHandler_GetInfoWithNonExistentUser() { req := &invitepb.GetInfoRequest{ - Code: "999999999", + UserId: 999999, } resp, err := s.inviteClient.GetInfo(context.Background(), req) @@ -41,17 +41,6 @@ func (s *IntegrationSuite) TestInviteHandler_GetInfoWithInvalidCode() { s.Equal(codes.NotFound, st.Code()) } -func (s *IntegrationSuite) TestInviteHandler_GetInfoWithInvalidCodeFormat() { - req := &invitepb.GetInfoRequest{ - Code: "invalid-code", - } - - resp, err := s.inviteClient.GetInfo(context.Background(), req) - - s.Error(err) - s.Nil(resp) -} - func (s *IntegrationSuite) TestInviteHandler_GenerateAndGetInfoFlow() { ctx := context.Background() @@ -87,17 +76,14 @@ func (s *IntegrationSuite) TestInviteHandler_GenerateAndGetInfoFlow() { s.NotNil(generateResp.ExpiresAt) getInfoReq := &invitepb.GetInfoRequest{ - Code: generateResp.Code, + UserId: validateResp.UserId, } infoResp, err := s.inviteClient.GetInfo(ctx, getInfoReq) s.NoError(err) s.NotNil(infoResp) s.Equal(generateResp.Code, infoResp.Code) - s.Equal(validateResp.UserId, infoResp.UserId) s.Equal(generateResp.MaxUses, infoResp.CanBeUsedCount) - s.Equal(int32(0), infoResp.UsedCount) - s.True(infoResp.IsActive) } func (s *IntegrationSuite) TestInviteHandler_GenerateWithInvalidTTL() { diff --git a/tests/ownership_test.go b/tests/ownership_test.go new file mode 100644 index 0000000..1f5260f --- /dev/null +++ b/tests/ownership_test.go @@ -0,0 +1,214 @@ +package tests + +import ( + "context" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" + supplierpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/supplier" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" +) + +func (s *IntegrationSuite) TestOwnership_GetMailingListByID_AnotherUsersRequest() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + user1ID := validateResp.UserId + + createTZReq := &requestpb.CreateTZRequest{ + UserId: user1ID, + RequestTxt: "Нужны поставщики для теста ownership", + } + + createTZResp, err := s.requestClient.CreateTZ(ctx, createTZReq) + s.NoError(err) + s.NotEmpty(createTZResp.RequestId) + requestID := createTZResp.RequestId + + _, _, user2ID := s.createSecondTestUser() + + getMailingByIDReq := &requestpb.GetMailingListByIDRequest{ + RequestId: requestID, + UserId: user2ID, + } + + resp, err := s.requestClient.GetMailingListByID(ctx, getMailingByIDReq) + s.Error(err) + s.Nil(resp) + + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.PermissionDenied, st.Code()) +} + +func (s *IntegrationSuite) TestOwnership_ApproveTZ_AnotherUsersRequest() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + user1ID := validateResp.UserId + + createTZReq := &requestpb.CreateTZRequest{ + UserId: user1ID, + RequestTxt: "Нужны поставщики для теста ownership approve", + } + + createTZResp, err := s.requestClient.CreateTZ(ctx, createTZReq) + s.NoError(err) + s.NotEmpty(createTZResp.RequestId) + requestID := createTZResp.RequestId + + _, _, user2ID := s.createSecondTestUser() + + approveTZReq := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Утвержденное ТЗ от чужого пользователя", + UserId: user2ID, + } + + resp, err := s.requestClient.ApproveTZ(ctx, approveTZReq) + s.Error(err) + s.Nil(resp) + + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.PermissionDenied, st.Code()) +} + +func (s *IntegrationSuite) TestOwnership_ExportExcel_AnotherUsersRequest() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + user1ID := validateResp.UserId + + createTZReq := &requestpb.CreateTZRequest{ + UserId: user1ID, + RequestTxt: "Нужны поставщики для теста ownership export", + } + + createTZResp, err := s.requestClient.CreateTZ(ctx, createTZReq) + s.NoError(err) + s.NotEmpty(createTZResp.RequestId) + requestID := createTZResp.RequestId + + approveTZReq := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Утвержденное ТЗ для экспорта", + UserId: user1ID, + } + + _, err = s.requestClient.ApproveTZ(ctx, approveTZReq) + s.NoError(err) + + _, _, user2ID := s.createSecondTestUser() + + exportReq := &supplierpb.ExportExcelRequest{ + RequestId: requestID, + UserId: user2ID, + } + + resp, err := s.supplierClient.ExportExcel(ctx, exportReq) + s.Error(err) + s.Nil(resp) + + st, ok := status.FromError(err) + s.True(ok) + s.Equal(codes.PermissionDenied, st.Code()) +} + +func (s *IntegrationSuite) TestOwnership_GetMailingListByID_OwnRequest_Success() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + userID := validateResp.UserId + + createTZReq := &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Нужны поставщики для теста ownership success", + } + + createTZResp, err := s.requestClient.CreateTZ(ctx, createTZReq) + s.NoError(err) + s.NotEmpty(createTZResp.RequestId) + requestID := createTZResp.RequestId + + approveTZReq := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Утвержденное ТЗ", + UserId: userID, + } + + _, err = s.requestClient.ApproveTZ(ctx, approveTZReq) + s.NoError(err) + + getMailingByIDReq := &requestpb.GetMailingListByIDRequest{ + RequestId: requestID, + UserId: userID, + } + + resp, err := s.requestClient.GetMailingListByID(ctx, getMailingByIDReq) + s.NoError(err) + s.NotNil(resp) + s.NotNil(resp.Detail) + s.Equal(requestID, resp.Detail.RequestId) +} diff --git a/tests/repository_test.go b/tests/repository_test.go index 0476d92..7ad8029 100644 --- a/tests/repository_test.go +++ b/tests/repository_test.go @@ -11,7 +11,7 @@ import ( "github.com/google/uuid" ) -func (s *IntegrationSuite) TestRepository_InviteIncrementUsedCount() { +func (s *IntegrationSuite) TestRepository_InviteDecrementCanBeUsedCount() { inviteRepo := repository.NewInviteRepository(s.pool) ctx := context.Background() @@ -28,12 +28,19 @@ func (s *IntegrationSuite) TestRepository_InviteIncrementUsedCount() { err = inviteRepo.Create(ctx, invite) s.Require().NoError(err) - err = inviteRepo.IncrementUsedCount(ctx, invite.Code) + tx, err := s.pool.Begin(ctx) + s.Require().NoError(err) + defer func() { _ = tx.Rollback(ctx) }() + + err = inviteRepo.DecrementCanBeUsedCountTx(ctx, tx, invite.Code) + s.NoError(err) + + err = tx.Commit(ctx) s.NoError(err) found, err := inviteRepo.FindByCode(ctx, invite.Code) s.NoError(err) - s.Equal(1, found.UsedCount) + s.Equal(9, found.CanBeUsedCount) } func (s *IntegrationSuite) TestRepository_InviteDeactivateExpired() { @@ -78,6 +85,43 @@ func (s *IntegrationSuite) TestRepository_InviteGetUserInvites() { s.GreaterOrEqual(len(invites), 1) } +func (s *IntegrationSuite) TestRepository_InviteFindActiveByUserID() { + inviteRepo := repository.NewInviteRepository(s.pool) + ctx := context.Background() + + var userID int + err := s.pool.QueryRow(ctx, "SELECT id FROM users LIMIT 1").Scan(&userID) + s.Require().NoError(err) + + _, err = s.pool.Exec(ctx, "UPDATE invite_codes SET is_active = false WHERE user_id = $1", userID) + s.Require().NoError(err) + + invite := &model.InviteCode{ + UserID: userID, + Code: time.Now().UnixNano(), + CanBeUsedCount: 5, + ExpiresAt: time.Now().Add(24 * time.Hour), + } + err = inviteRepo.Create(ctx, invite) + s.Require().NoError(err) + + found, err := inviteRepo.FindActiveByUserID(ctx, userID) + s.NoError(err) + s.NotNil(found) + s.Equal(invite.Code, found.Code) + s.Equal(userID, found.UserID) + s.True(found.IsActive) +} + +func (s *IntegrationSuite) TestRepository_InviteFindActiveByUserIDNotFound() { + inviteRepo := repository.NewInviteRepository(s.pool) + ctx := context.Background() + + found, err := inviteRepo.FindActiveByUserID(ctx, 999999) + s.Error(err) + s.Nil(found) +} + func (s *IntegrationSuite) TestRepository_SessionRevoke() { sessionRepo := repository.NewSessionRepository(s.pool) ctx := context.Background() @@ -199,6 +243,42 @@ func (s *IntegrationSuite) TestRepository_SupplierBulkInsertAndDelete() { s.Equal(0, len(found)) } +func (s *IntegrationSuite) TestRepository_SupplierBulkInsertReturnsIDs() { + supplierRepo := repository.NewSupplierRepository(s.pool) + requestRepo := repository.NewRequestRepository(s.pool) + ctx := context.Background() + + var userID int + err := s.pool.QueryRow(ctx, "SELECT id FROM users LIMIT 1").Scan(&userID) + s.Require().NoError(err) + + req := &model.Request{ + UserID: userID, + RequestTxt: "Test request for supplier IDs", + } + err = requestRepo.Create(ctx, req) + s.Require().NoError(err) + + suppliers := []*model.Supplier{ + {Name: "Supplier A", Email: "a@test.com", Phone: "+7111"}, + {Name: "Supplier B", Email: "b@test.com", Phone: "+7222"}, + {Name: "Supplier C", Email: "c@test.com", Phone: "+7333"}, + } + + err = supplierRepo.BulkInsert(ctx, req.ID, suppliers) + s.NoError(err) + + for i, sup := range suppliers { + s.NotZero(sup.ID, "Supplier %d should have non-zero ID after BulkInsert", i) + s.Equal(req.ID, sup.RequestID, "Supplier %d should have correct RequestID", i) + } + + s.NotEqual(suppliers[0].ID, suppliers[1].ID, "Suppliers should have different IDs") + s.NotEqual(suppliers[1].ID, suppliers[2].ID, "Suppliers should have different IDs") + + _ = supplierRepo.DeleteByRequestID(ctx, req.ID) +} + func (s *IntegrationSuite) TestRepository_TokenUsageCreate() { tokenRepo := repository.NewTokenUsageRepository(s.pool) requestRepo := repository.NewRequestRepository(s.pool) @@ -226,6 +306,64 @@ func (s *IntegrationSuite) TestRepository_TokenUsageCreate() { s.NoError(err) } +func (s *IntegrationSuite) TestRepository_TokenUsageGetBalanceStatistics() { + tokenRepo := repository.NewTokenUsageRepository(s.pool) + requestRepo := repository.NewRequestRepository(s.pool) + ctx := context.Background() + + var userID int + err := s.pool.QueryRow(ctx, "SELECT id FROM users LIMIT 1").Scan(&userID) + s.Require().NoError(err) + + req := &model.Request{ + UserID: userID, + RequestTxt: "Test request for balance statistics", + } + err = requestRepo.Create(ctx, req) + s.Require().NoError(err) + + usage1 := &model.TokenUsage{ + RequestID: req.ID, + RequestTokenCount: 100, + ResponseTokenCount: 50, + TokenCost: 10.0, + Type: "openai", + } + err = tokenRepo.Create(ctx, usage1) + s.Require().NoError(err) + + usage2 := &model.TokenUsage{ + RequestID: req.ID, + RequestTokenCount: 200, + ResponseTokenCount: 100, + TokenCost: 20.0, + Type: "perplexity", + } + err = tokenRepo.Create(ctx, usage2) + s.Require().NoError(err) + + averageCost, history, err := tokenRepo.GetBalanceStatistics(ctx, userID) + s.NoError(err) + s.Greater(averageCost, 0.0) + s.GreaterOrEqual(len(history), 2) + + for _, item := range history { + s.NotEmpty(item.OperationID) + s.NotEmpty(item.Data) + s.GreaterOrEqual(item.Amount, 0.0) + } +} + +func (s *IntegrationSuite) TestRepository_TokenUsageGetBalanceStatisticsEmpty() { + tokenRepo := repository.NewTokenUsageRepository(s.pool) + ctx := context.Background() + + averageCost, history, err := tokenRepo.GetBalanceStatistics(ctx, 999999) + s.NoError(err) + s.Equal(0.0, averageCost) + s.Empty(history) +} + func (s *IntegrationSuite) TestRepository_UserCreate() { userRepo := repository.NewUserRepository(s.pool, testCryptoSecret) ctx := context.Background() diff --git a/tests/request_handler_test.go b/tests/request_handler_test.go index b9f9cd0..150ecb1 100644 --- a/tests/request_handler_test.go +++ b/tests/request_handler_test.go @@ -136,3 +136,133 @@ func (s *IntegrationSuite) TestRequestHandler_GetMailingListWithValidUser() { s.NoError(err) s.NotNil(resp) } + +func (s *IntegrationSuite) TestRequestHandler_CreateTZWithFile() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + + req := &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Нужны поставщики металлоконструкций", + FileData: []byte("Содержимое файла с дополнительными требованиями"), + FileName: "requirements.txt", + } + + resp, err := s.requestClient.CreateTZ(ctx, req) + s.NoError(err) + s.NotNil(resp) + s.NotEmpty(resp.RequestId) + s.NotEmpty(resp.TzText) +} + +func (s *IntegrationSuite) TestRequestHandler_ApproveTZSuccess() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + userID := validateResp.UserId + + createReq := &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Нужны поставщики кирпича для строительства", + } + + createResp, err := s.requestClient.CreateTZ(ctx, createReq) + s.NoError(err) + s.NotEmpty(createResp.RequestId) + + approveReq := &requestpb.ApproveTZRequest{ + RequestId: createResp.RequestId, + FinalTz: "Утвержденное техническое задание на поставку кирпича", + UserId: userID, + } + + approveResp, err := s.requestClient.ApproveTZ(ctx, approveReq) + s.NoError(err) + s.NotNil(approveResp) + s.NotEmpty(approveResp.RequestId) + s.NotNil(approveResp.Suppliers) +} + +func (s *IntegrationSuite) TestRequestHandler_GetMailingListByIDSuccess() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + userID := validateResp.UserId + + createReq := &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Нужны поставщики бетона", + } + + createResp, err := s.requestClient.CreateTZ(ctx, createReq) + s.NoError(err) + s.NotEmpty(createResp.RequestId) + requestID := createResp.RequestId + + approveReq := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Утвержденное ТЗ на поставку бетона", + UserId: userID, + } + + _, err = s.requestClient.ApproveTZ(ctx, approveReq) + s.NoError(err) + + getByIDReq := &requestpb.GetMailingListByIDRequest{ + RequestId: requestID, + UserId: userID, + } + + getByIDResp, err := s.requestClient.GetMailingListByID(ctx, getByIDReq) + s.NoError(err) + s.NotNil(getByIDResp) + s.NotNil(getByIDResp.Detail) + s.Equal(requestID, getByIDResp.Detail.RequestId) + s.Greater(len(getByIDResp.Detail.Suppliers), 0) +} diff --git a/tests/statistics_test.go b/tests/statistics_test.go new file mode 100644 index 0000000..fc77ba8 --- /dev/null +++ b/tests/statistics_test.go @@ -0,0 +1,131 @@ +package tests + +import ( + "context" + + "git.techease.ru/Smart-search/smart-search-back/internal/model" + "git.techease.ru/Smart-search/smart-search-back/internal/repository" + userpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/user" +) + +func (s *IntegrationSuite) TestStatistics_CorrectCountWithMultipleUsers() { + ctx := context.Background() + + _, _, user1ID := s.createUniqueTestUser("stats_user1", 1000.0) + _, _, user2ID := s.createUniqueTestUser("stats_user2", 1000.0) + + requestRepo := repository.NewRequestRepository(s.pool) + supplierRepo := repository.NewSupplierRepository(s.pool) + + req1 := &model.Request{UserID: user1ID, RequestTxt: "Request 1 with TZ"} + err := requestRepo.Create(ctx, req1) + s.Require().NoError(err) + + req2 := &model.Request{UserID: user1ID, RequestTxt: "Request 2 with TZ"} + err = requestRepo.Create(ctx, req2) + s.Require().NoError(err) + + _, err = s.pool.Exec(ctx, ` + INSERT INTO requests_for_suppliers (user_id, request_txt, mailling_status_id) + VALUES ($1, NULL, 1) + `, user1ID) + s.Require().NoError(err) + + suppliers1 := []*model.Supplier{ + {Name: "Supplier 1", Email: "s1@test.com"}, + {Name: "Supplier 2", Email: "s2@test.com"}, + {Name: "Supplier 3", Email: "s3@test.com"}, + } + err = supplierRepo.BulkInsert(ctx, req1.ID, suppliers1) + s.Require().NoError(err) + + suppliers2 := []*model.Supplier{ + {Name: "Supplier 4", Email: "s4@test.com"}, + {Name: "Supplier 5", Email: "s5@test.com"}, + } + err = supplierRepo.BulkInsert(ctx, req2.ID, suppliers2) + s.Require().NoError(err) + + req4 := &model.Request{UserID: user2ID, RequestTxt: "User2 Request with TZ"} + err = requestRepo.Create(ctx, req4) + s.Require().NoError(err) + + suppliers3 := []*model.Supplier{ + {Name: "Supplier 6", Email: "s6@test.com"}, + {Name: "Supplier 7", Email: "s7@test.com"}, + {Name: "Supplier 8", Email: "s8@test.com"}, + {Name: "Supplier 9", Email: "s9@test.com"}, + {Name: "Supplier 10", Email: "s10@test.com"}, + } + err = supplierRepo.BulkInsert(ctx, req4.ID, suppliers3) + s.Require().NoError(err) + + resp, err := s.userClient.GetStatistics(ctx, &userpb.GetStatisticsRequest{ + UserId: int64(user1ID), + }) + s.NoError(err) + s.NotNil(resp) + + s.Equal("3", resp.RequestsCount) + s.Equal("5", resp.SuppliersCount) + s.Equal("2", resp.CreatedTz) + + resp2, err := s.userClient.GetStatistics(ctx, &userpb.GetStatisticsRequest{ + UserId: int64(user2ID), + }) + s.NoError(err) + s.NotNil(resp2) + + s.Equal("1", resp2.RequestsCount) + s.Equal("5", resp2.SuppliersCount) + s.Equal("1", resp2.CreatedTz) +} + +func (s *IntegrationSuite) TestStatistics_CreatedTZNotMultipliedBySuppliers() { + ctx := context.Background() + + _, _, userID := s.createUniqueTestUser("stats_multiply", 1000.0) + + requestRepo := repository.NewRequestRepository(s.pool) + supplierRepo := repository.NewSupplierRepository(s.pool) + + req := &model.Request{UserID: userID, RequestTxt: "Single request with TZ"} + err := requestRepo.Create(ctx, req) + s.Require().NoError(err) + + suppliers := make([]*model.Supplier, 10) + for i := 0; i < 10; i++ { + suppliers[i] = &model.Supplier{ + Name: "Supplier", + Email: "supplier@test.com", + } + } + err = supplierRepo.BulkInsert(ctx, req.ID, suppliers) + s.Require().NoError(err) + + resp, err := s.userClient.GetStatistics(ctx, &userpb.GetStatisticsRequest{ + UserId: int64(userID), + }) + s.NoError(err) + s.NotNil(resp) + + s.Equal("1", resp.RequestsCount) + s.Equal("10", resp.SuppliersCount) + s.Equal("1", resp.CreatedTz) +} + +func (s *IntegrationSuite) TestStatistics_EmptyForNewUser() { + ctx := context.Background() + + _, _, userID := s.createUniqueTestUser("stats_empty", 1000.0) + + resp, err := s.userClient.GetStatistics(ctx, &userpb.GetStatisticsRequest{ + UserId: int64(userID), + }) + s.NoError(err) + s.NotNil(resp) + + s.Equal("0", resp.RequestsCount) + s.Equal("0", resp.SuppliersCount) + s.Equal("0", resp.CreatedTz) +} diff --git a/tests/supplier_handler_test.go b/tests/supplier_handler_test.go index 0ff0ed8..f2b779f 100644 --- a/tests/supplier_handler_test.go +++ b/tests/supplier_handler_test.go @@ -93,3 +93,58 @@ func (s *IntegrationSuite) TestSupplierHandler_ExportExcelWithValidRequest() { s.NotEmpty(exportResp.FileName) s.Equal("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", exportResp.MimeType) } + +func (s *IntegrationSuite) TestSupplierHandler_ExportExcelWithSuppliers() { + ctx := context.Background() + + loginReq := &authpb.LoginRequest{ + Email: "test@example.com", + Password: "testpassword", + Ip: "127.0.0.1", + UserAgent: "integration-test", + } + + loginResp, err := s.authClient.Login(ctx, loginReq) + s.NoError(err) + + validateReq := &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + } + + validateResp, err := s.authClient.Validate(ctx, validateReq) + s.NoError(err) + userID := validateResp.UserId + + createReq := &requestpb.CreateTZRequest{ + UserId: userID, + RequestTxt: "Нужны поставщики офисной мебели для большого офиса", + } + + createResp, err := s.requestClient.CreateTZ(ctx, createReq) + s.NoError(err) + s.NotEmpty(createResp.RequestId) + requestID := createResp.RequestId + + approveReq := &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Техническое задание на поставку офисной мебели", + UserId: userID, + } + + approveResp, err := s.requestClient.ApproveTZ(ctx, approveReq) + s.NoError(err) + s.NotEmpty(approveResp.RequestId) + + exportReq := &supplierpb.ExportExcelRequest{ + RequestId: requestID, + UserId: userID, + } + + exportResp, err := s.supplierClient.ExportExcel(ctx, exportReq) + s.NoError(err) + s.NotNil(exportResp) + s.NotEmpty(exportResp.FileName) + s.NotEmpty(exportResp.FileData) + s.Equal("application/vnd.openxmlformats-officedocument.spreadsheetml.sheet", exportResp.MimeType) + s.Greater(len(exportResp.FileData), 1000) +} diff --git a/tests/transaction_rollback_test.go b/tests/transaction_rollback_test.go new file mode 100644 index 0000000..cde240a --- /dev/null +++ b/tests/transaction_rollback_test.go @@ -0,0 +1,194 @@ +package tests + +import ( + "sync" + "sync/atomic" + + authpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/auth" + requestpb "git.techease.ru/Smart-search/smart-search-back/pkg/pb/request" +) + +func (s *IntegrationSuite) TestTransaction_CreateTZ_InsufficientBalance_Rollback() { + email, password, userID := s.createUniqueTestUser("insufficient_tz", 0.001) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + initialBalance := s.getUserBalance(userID) + s.T().Logf("Initial balance: %.4f", initialBalance) + + _, err = s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Тест с недостаточным балансом", + }) + + if err != nil { + s.T().Logf("CreateTZ failed as expected: %v", err) + + finalBalance := s.getUserBalance(userID) + s.T().Logf("Final balance after failed CreateTZ: %.4f", finalBalance) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным после rollback") + } +} + +func (s *IntegrationSuite) TestTransaction_ApproveTZ_InsufficientBalance_NoSuppliers() { + email, password, userID := s.createUniqueTestUser("approve_insufficient", 1000.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Тест approve с недостаточным балансом", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + _, err = s.pool.Exec(s.ctx, "UPDATE users SET balance = 0.001 WHERE id = $1", userID) + s.Require().NoError(err) + + suppliersBeforeApprove := s.getRequestSuppliersCount(requestID) + s.T().Logf("Suppliers before ApproveTZ: %d", suppliersBeforeApprove) + + _, err = s.requestClient.ApproveTZ(s.ctx, &requestpb.ApproveTZRequest{ + RequestId: requestID, + FinalTz: "Утвержденное ТЗ", + UserId: validateResp.UserId, + }) + + if err != nil { + s.T().Logf("ApproveTZ failed as expected: %v", err) + + suppliersAfterApprove := s.getRequestSuppliersCount(requestID) + s.T().Logf("Suppliers after failed ApproveTZ: %d", suppliersAfterApprove) + + finalBalance := s.getUserBalance(userID) + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным") + } +} + +func (s *IntegrationSuite) TestTransaction_ConcurrentCreateTZ_BalanceAtomicity() { + email, password, userID := s.createUniqueTestUser("concurrent_tz", 100.0) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + var wg sync.WaitGroup + var successCount int32 + var errorCount int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func(idx int) { + defer wg.Done() + + <-startBarrier + + _, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Параллельный тест CreateTZ", + }) + + if err == nil { + atomic.AddInt32(&successCount, 1) + } else { + atomic.AddInt32(&errorCount, 1) + } + }(i) + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Concurrent CreateTZ - Success: %d, Errors: %d", successCount, errorCount) + + finalBalance := s.getUserBalance(userID) + s.T().Logf("Final balance: %.4f", finalBalance) + + s.GreaterOrEqual(finalBalance, 0.0, + "Баланс не должен быть отрицательным после параллельных операций") +} + +func (s *IntegrationSuite) TestTransaction_TokenUsage_BalanceConsistency() { + email, password, userID := s.createUniqueTestUser("token_consistency", 1000.0) + + initialBalance := s.getUserBalance(userID) + s.T().Logf("Initial balance: %.4f", initialBalance) + + loginResp, err := s.authClient.Login(s.ctx, &authpb.LoginRequest{ + Email: email, + Password: password, + Ip: "127.0.0.1", + UserAgent: "test-agent", + }) + s.Require().NoError(err) + + validateResp, err := s.authClient.Validate(s.ctx, &authpb.ValidateRequest{ + AccessToken: loginResp.AccessToken, + }) + s.Require().NoError(err) + + createResp, err := s.requestClient.CreateTZ(s.ctx, &requestpb.CreateTZRequest{ + UserId: validateResp.UserId, + RequestTxt: "Тест consistency token_usage и balance", + }) + s.Require().NoError(err) + requestID := createResp.RequestId + + tokenUsageCount := s.getTokenUsageCount(requestID) + s.T().Logf("Token usage records for request: %d", tokenUsageCount) + + finalBalance := s.getUserBalance(userID) + balanceDelta := initialBalance - finalBalance + s.T().Logf("Balance delta: %.4f", balanceDelta) + + if tokenUsageCount > 0 { + s.Greater(balanceDelta, 0.0, + "Баланс должен уменьшиться при наличии token_usage записей") + } + + var totalTokenCost float64 + err = s.pool.QueryRow(s.ctx, + "SELECT COALESCE(SUM(token_cost), 0) FROM request_token_usage WHERE request_id = $1::uuid", + requestID, + ).Scan(&totalTokenCost) + s.NoError(err) + + s.T().Logf("Total token cost from DB: %.4f, Balance delta: %.4f", totalTokenCost, balanceDelta) +} diff --git a/tests/user_handler_test.go b/tests/user_handler_test.go index dbf90d9..a1dea52 100644 --- a/tests/user_handler_test.go +++ b/tests/user_handler_test.go @@ -54,7 +54,7 @@ func (s *IntegrationSuite) TestUserHandler_GetStatisticsWithNonExistentUser() { } s.NotNil(resp) - s.Equal(int32(0), resp.TotalRequests) + s.Equal("0", resp.RequestsCount) } func (s *IntegrationSuite) TestUserHandler_GetBalanceStatistics() { @@ -72,8 +72,7 @@ func (s *IntegrationSuite) TestUserHandler_GetBalanceStatistics() { } s.NotNil(resp) - s.GreaterOrEqual(resp.Balance, 0.0) - s.GreaterOrEqual(resp.TotalRequests, int32(0)) + s.GreaterOrEqual(resp.AverageCost, 0.0) } func (s *IntegrationSuite) TestUserHandler_GetInfoWithValidUser() { @@ -167,8 +166,7 @@ func (s *IntegrationSuite) TestUserHandler_GetStatisticsWithValidUser() { resp, err := s.userClient.GetStatistics(ctx, req) s.NoError(err) s.NotNil(resp) - s.GreaterOrEqual(resp.TotalRequests, int32(0)) - s.GreaterOrEqual(resp.SuccessfulRequests, int32(0)) - s.GreaterOrEqual(resp.FailedRequests, int32(0)) - s.GreaterOrEqual(resp.TotalSpent, 0.0) + s.NotEmpty(resp.SuppliersCount) + s.NotEmpty(resp.RequestsCount) + s.NotEmpty(resp.CreatedTz) } diff --git a/tests/worker_concurrent_test.go b/tests/worker_concurrent_test.go new file mode 100644 index 0000000..53e410a --- /dev/null +++ b/tests/worker_concurrent_test.go @@ -0,0 +1,181 @@ +package tests + +import ( + "fmt" + "sync" + "sync/atomic" + "time" + + "github.com/google/uuid" +) + +func (s *IntegrationSuite) TestWorkerConcurrent_SessionCleanup_MassExpired() { + _, _, userID := s.createUniqueTestUser("session_cleanup", 100.0) + + expiredTime := time.Now().Add(-24 * time.Hour) + validTime := time.Now().Add(24 * time.Hour) + + expiredCount := 100 + validCount := 50 + + for i := 0; i < expiredCount; i++ { + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO sessions (user_id, access_token, refresh_token, ip, user_agent, expires_at) + VALUES ($1, $2, $3, '127.0.0.1', 'test-agent', $4) + `, userID, + fmt.Sprintf("expired_access_%d_%s", i, uuid.New().String()), + fmt.Sprintf("expired_refresh_%d_%s", i, uuid.New().String()), + expiredTime, + ) + s.Require().NoError(err) + } + + for i := 0; i < validCount; i++ { + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO sessions (user_id, access_token, refresh_token, ip, user_agent, expires_at) + VALUES ($1, $2, $3, '127.0.0.1', 'test-agent', $4) + `, userID, + fmt.Sprintf("valid_access_%d_%s", i, uuid.New().String()), + fmt.Sprintf("valid_refresh_%d_%s", i, uuid.New().String()), + validTime, + ) + s.Require().NoError(err) + } + + var totalBefore int + err := s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM sessions WHERE user_id = $1", userID, + ).Scan(&totalBefore) + s.Require().NoError(err) + s.T().Logf("Sessions before cleanup: %d", totalBefore) + + var wg sync.WaitGroup + var totalDeleted int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + result, err := s.pool.Exec(s.ctx, ` + DELETE FROM sessions + WHERE expires_at < now() + OR (revoked_at IS NOT NULL AND revoked_at < now() - interval '30 days') + `) + + if err == nil { + atomic.AddInt32(&totalDeleted, int32(result.RowsAffected())) + } + }() + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Total deleted by concurrent cleanup: %d", totalDeleted) + + var validRemaining int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM sessions WHERE user_id = $1 AND expires_at > now()", userID, + ).Scan(&validRemaining) + s.NoError(err) + + s.T().Logf("Valid sessions remaining: %d (expected: %d)", validRemaining, validCount) + + s.Equal(validCount, validRemaining, + "Все валидные сессии должны остаться после cleanup") + + s.GreaterOrEqual(int(totalDeleted), expiredCount, + "Все истекшие сессии должны быть удалены") +} + +func (s *IntegrationSuite) TestWorkerConcurrent_InviteCleanup_MassExpired() { + _, _, userID := s.createUniqueTestUser("invite_cleanup", 100.0) + + _, err := s.pool.Exec(s.ctx, "UPDATE users SET invites_limit = 200 WHERE id = $1", userID) + s.Require().NoError(err) + + expiredTime := time.Now().Add(-24 * time.Hour) + validTime := time.Now().Add(24 * time.Hour) + + expiredCount := 100 + validCount := 50 + + for i := 0; i < expiredCount; i++ { + code := int64(30000000 + i) + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO invite_codes (user_id, code, can_be_used_count, expires_at, is_active) + VALUES ($1, $2, 5, $3, true) + `, userID, code, expiredTime) + s.Require().NoError(err) + } + + for i := 0; i < validCount; i++ { + code := int64(40000000 + i) + _, err := s.pool.Exec(s.ctx, ` + INSERT INTO invite_codes (user_id, code, can_be_used_count, expires_at, is_active) + VALUES ($1, $2, 5, $3, true) + `, userID, code, validTime) + s.Require().NoError(err) + } + + var activeBefore int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND is_active = true", userID, + ).Scan(&activeBefore) + s.Require().NoError(err) + s.T().Logf("Active invites before cleanup: %d", activeBefore) + + var wg sync.WaitGroup + var totalDeactivated int32 + goroutines := 10 + + startBarrier := make(chan struct{}) + + for i := 0; i < goroutines; i++ { + wg.Add(1) + go func() { + defer wg.Done() + <-startBarrier + + result, err := s.pool.Exec(s.ctx, ` + UPDATE invite_codes + SET is_active = false + WHERE expires_at < now() AND is_active = true + `) + + if err == nil { + atomic.AddInt32(&totalDeactivated, int32(result.RowsAffected())) + } + }() + } + + close(startBarrier) + wg.Wait() + + s.T().Logf("Total deactivated by concurrent cleanup: %d", totalDeactivated) + + var activeRemaining int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND is_active = true", userID, + ).Scan(&activeRemaining) + s.NoError(err) + + s.T().Logf("Active invites remaining: %d (expected: %d)", activeRemaining, validCount) + + s.Equal(validCount, activeRemaining, + "Все валидные инвайты должны остаться активными после cleanup") + + var expiredStillActive int + err = s.pool.QueryRow(s.ctx, + "SELECT COUNT(*) FROM invite_codes WHERE user_id = $1 AND expires_at < now() AND is_active = true", userID, + ).Scan(&expiredStillActive) + s.NoError(err) + + s.Equal(0, expiredStillActive, + "Не должно остаться активных истекших инвайтов") +}