package repository import ( "context" "errors" "git.techease.ru/Smart-search/smart-search-back/internal/model" errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors" sq "github.com/Masterminds/squirrel" "github.com/jackc/pgx/v5" "github.com/jackc/pgx/v5/pgxpool" ) type inviteRepository struct { pool *pgxpool.Pool qb sq.StatementBuilderType } func NewInviteRepository(pool *pgxpool.Pool) InviteRepository { return &inviteRepository{ pool: pool, qb: sq.StatementBuilder.PlaceholderFormat(sq.Dollar), } } func (r *inviteRepository) Create(ctx context.Context, invite *model.InviteCode) error { return r.createWithExecutor(ctx, r.pool, invite) } func (r *inviteRepository) CreateTx(ctx context.Context, tx pgx.Tx, invite *model.InviteCode) error { return r.createWithExecutor(ctx, tx, invite) } func (r *inviteRepository) createWithExecutor(ctx context.Context, exec DBTX, invite *model.InviteCode) error { query := r.qb.Insert("invite_codes").Columns( "user_id", "code", "can_be_used_count", "expires_at", ).Values( invite.UserID, invite.Code, invite.CanBeUsedCount, invite.ExpiresAt, ).Suffix("RETURNING id, created_at") sqlQuery, args, err := query.ToSql() if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } err = exec.QueryRow(ctx, sqlQuery, args...).Scan(&invite.ID, &invite.CreatedAt) if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to create invite code", err) } return nil } func (r *inviteRepository) FindByCode(ctx context.Context, code int64) (*model.InviteCode, error) { query := r.qb.Select( "id", "user_id", "code", "can_be_used_count", "is_active", "created_at", "expires_at", ).From("invite_codes").Where(sq.Eq{"code": code}) sqlQuery, args, err := query.ToSql() if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } invite := &model.InviteCode{} err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, ) if errors.Is(err, pgx.ErrNoRows) { return nil, errs.NewBusinessError(errs.UserNotFound, "invite code not found") } if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to find invite code", err) } return invite, nil } func (r *inviteRepository) FindActiveByCode(ctx context.Context, code int64) (*model.InviteCode, error) { query := r.qb.Select( "id", "user_id", "code", "can_be_used_count", "is_active", "created_at", "expires_at", ).From("invite_codes").Where(sq.And{ sq.Eq{"code": code}, sq.Eq{"is_active": true}, sq.Expr("expires_at > now()"), sq.Expr("can_be_used_count > 0"), }) sqlQuery, args, err := query.ToSql() if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } invite := &model.InviteCode{} err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, ) if errors.Is(err, pgx.ErrNoRows) { return nil, errs.NewBusinessError(errs.InviteInvalidOrExpired, "invite code is invalid or expired") } if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to find active invite code", err) } return invite, nil } func (r *inviteRepository) FindActiveByUserID(ctx context.Context, userID int) (*model.InviteCode, error) { query := r.qb.Select( "id", "user_id", "code", "can_be_used_count", "is_active", "created_at", "expires_at", ).From("invite_codes").Where(sq.And{ sq.Eq{"user_id": userID}, sq.Eq{"is_active": true}, sq.Expr("expires_at > now()"), sq.Expr("can_be_used_count > 0"), }).OrderBy("created_at DESC").Limit(1) sqlQuery, args, err := query.ToSql() if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } invite := &model.InviteCode{} err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan( &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, ) if errors.Is(err, pgx.ErrNoRows) { return nil, errs.NewBusinessError(errs.InviteInvalidOrExpired, "no active invite code found") } if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to find active invite code by user", err) } return invite, nil } func (r *inviteRepository) DecrementCanBeUsedCountTx(ctx context.Context, tx pgx.Tx, code int64) error { query := r.qb.Update("invite_codes"). Set("can_be_used_count", sq.Expr("can_be_used_count - 1")). Set("is_active", sq.Expr("CASE WHEN can_be_used_count - 1 <= 0 THEN false ELSE is_active END")). Where(sq.And{ sq.Eq{"code": code}, sq.Expr("can_be_used_count > 0"), sq.Eq{"is_active": true}, sq.Expr("expires_at > now()"), }) sqlQuery, args, err := query.ToSql() if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } result, err := tx.Exec(ctx, sqlQuery, args...) if err != nil { return errs.NewInternalError(errs.DatabaseError, "failed to decrement can_be_used_count", err) } if result.RowsAffected() == 0 { return errs.NewBusinessError(errs.InviteInvalidOrExpired, "invite code is invalid, expired, or exhausted") } return nil } func (r *inviteRepository) DeactivateExpired(ctx context.Context) (int, error) { query := r.qb.Update("invite_codes"). Set("is_active", false). Where(sq.And{ sq.Expr("expires_at < now()"), sq.Eq{"is_active": true}, }) sqlQuery, args, err := query.ToSql() if err != nil { return 0, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } result, err := r.pool.Exec(ctx, sqlQuery, args...) if err != nil { return 0, errs.NewInternalError(errs.DatabaseError, "failed to deactivate expired invites", err) } return int(result.RowsAffected()), nil } func (r *inviteRepository) GetUserInvites(ctx context.Context, userID int) ([]*model.InviteCode, error) { query := r.qb.Select( "id", "user_id", "code", "can_be_used_count", "is_active", "created_at", "expires_at", ).From("invite_codes"). Where(sq.Eq{"user_id": userID}). OrderBy("created_at DESC") sqlQuery, args, err := query.ToSql() if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err) } rows, err := r.pool.Query(ctx, sqlQuery, args...) if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to get user invites", err) } defer rows.Close() var invites []*model.InviteCode for rows.Next() { invite := &model.InviteCode{} err := rows.Scan( &invite.ID, &invite.UserID, &invite.Code, &invite.CanBeUsedCount, &invite.IsActive, &invite.CreatedAt, &invite.ExpiresAt, ) if err != nil { return nil, errs.NewInternalError(errs.DatabaseError, "failed to scan invite", err) } invites = append(invites, invite) } return invites, nil }