200 lines
6.0 KiB
Go
200 lines
6.0 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
|
|
"smart-search-back/internal/model"
|
|
"smart-search-back/pkg/crypto"
|
|
errs "smart-search-back/pkg/errors"
|
|
|
|
sq "github.com/Masterminds/squirrel"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
type userRepository struct {
|
|
pool *pgxpool.Pool
|
|
qb sq.StatementBuilderType
|
|
cryptoHelper *crypto.Crypto
|
|
}
|
|
|
|
func NewUserRepository(pool *pgxpool.Pool, cryptoSecret string) UserRepository {
|
|
return &userRepository{
|
|
pool: pool,
|
|
qb: sq.StatementBuilder.PlaceholderFormat(sq.Dollar),
|
|
cryptoHelper: crypto.NewCrypto(cryptoSecret),
|
|
}
|
|
}
|
|
|
|
func (r *userRepository) FindByEmailHash(ctx context.Context, emailHash string) (*model.User, error) {
|
|
query := r.qb.Select(
|
|
"id", "email", "email_hash", "password_hash", "phone",
|
|
"user_name", "company_name", "balance", "payment_status",
|
|
"invites_issued", "invites_limit", "created_at",
|
|
).From("users").Where(sq.Eq{"email_hash": emailHash})
|
|
|
|
sqlQuery, args, err := query.ToSql()
|
|
if err != nil {
|
|
return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
|
}
|
|
|
|
user := &model.User{}
|
|
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(
|
|
&user.ID, &user.Email, &user.EmailHash, &user.PasswordHash,
|
|
&user.Phone, &user.UserName, &user.CompanyName, &user.Balance,
|
|
&user.PaymentStatus, &user.InvitesIssued, &user.InvitesLimit, &user.CreatedAt,
|
|
)
|
|
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, errs.NewBusinessError(errs.UserNotFound, "user not found")
|
|
}
|
|
if err != nil {
|
|
return nil, errs.NewInternalError(errs.DatabaseError, "failed to find user", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (r *userRepository) FindByID(ctx context.Context, userID int) (*model.User, error) {
|
|
query := r.qb.Select(
|
|
"id", "email", "email_hash", "password_hash", "phone",
|
|
"user_name", "company_name", "balance", "payment_status",
|
|
"invites_issued", "invites_limit", "created_at",
|
|
).From("users").Where(sq.Eq{"id": userID})
|
|
|
|
sqlQuery, args, err := query.ToSql()
|
|
if err != nil {
|
|
return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
|
}
|
|
|
|
user := &model.User{}
|
|
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(
|
|
&user.ID, &user.Email, &user.EmailHash, &user.PasswordHash,
|
|
&user.Phone, &user.UserName, &user.CompanyName, &user.Balance,
|
|
&user.PaymentStatus, &user.InvitesIssued, &user.InvitesLimit, &user.CreatedAt,
|
|
)
|
|
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return nil, errs.NewBusinessError(errs.UserNotFound, "user not found")
|
|
}
|
|
if err != nil {
|
|
return nil, errs.NewInternalError(errs.DatabaseError, "failed to find user", err)
|
|
}
|
|
|
|
return user, nil
|
|
}
|
|
|
|
func (r *userRepository) Create(ctx context.Context, user *model.User) error {
|
|
encryptedEmail, err := r.cryptoHelper.Encrypt(user.Email)
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.EncryptionError, "failed to encrypt email", err)
|
|
}
|
|
|
|
encryptedPhone, err := r.cryptoHelper.Encrypt(user.Phone)
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.EncryptionError, "failed to encrypt phone", err)
|
|
}
|
|
|
|
encryptedUserName, err := r.cryptoHelper.Encrypt(user.UserName)
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.EncryptionError, "failed to encrypt user name", err)
|
|
}
|
|
|
|
query := r.qb.Insert("users").Columns(
|
|
"email", "email_hash", "password_hash", "phone", "user_name",
|
|
"company_name", "balance", "payment_status",
|
|
).Values(
|
|
encryptedEmail, user.EmailHash, user.PasswordHash, encryptedPhone,
|
|
encryptedUserName, user.CompanyName, user.Balance, user.PaymentStatus,
|
|
).Suffix("RETURNING id")
|
|
|
|
sqlQuery, args, err := query.ToSql()
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
|
}
|
|
|
|
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&user.ID)
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.DatabaseError, "failed to create user", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *userRepository) UpdateBalance(ctx context.Context, userID int, delta float64) error {
|
|
query := r.qb.Update("users").
|
|
Set("balance", sq.Expr("balance + ?", delta)).
|
|
Where(sq.Eq{"id": userID})
|
|
|
|
sqlQuery, args, err := query.ToSql()
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
|
}
|
|
|
|
_, err = r.pool.Exec(ctx, sqlQuery, args...)
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.DatabaseError, "failed to update balance", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *userRepository) GetBalance(ctx context.Context, userID int) (float64, error) {
|
|
query := r.qb.Select("balance").From("users").Where(sq.Eq{"id": userID})
|
|
|
|
sqlQuery, args, err := query.ToSql()
|
|
if err != nil {
|
|
return 0, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
|
}
|
|
|
|
var balance float64
|
|
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&balance)
|
|
if errors.Is(err, pgx.ErrNoRows) {
|
|
return 0, errs.NewBusinessError(errs.UserNotFound, "user not found")
|
|
}
|
|
if err != nil {
|
|
return 0, errs.NewInternalError(errs.DatabaseError, "failed to get balance", err)
|
|
}
|
|
|
|
return balance, nil
|
|
}
|
|
|
|
func (r *userRepository) IncrementInvitesIssued(ctx context.Context, userID int) error {
|
|
query := r.qb.Update("users").
|
|
Set("invites_issued", sq.Expr("invites_issued + 1")).
|
|
Where(sq.Eq{"id": userID})
|
|
|
|
sqlQuery, args, err := query.ToSql()
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
|
}
|
|
|
|
_, err = r.pool.Exec(ctx, sqlQuery, args...)
|
|
if err != nil {
|
|
return errs.NewInternalError(errs.DatabaseError, "failed to increment invites issued", err)
|
|
}
|
|
|
|
return nil
|
|
}
|
|
|
|
func (r *userRepository) CheckInviteLimit(ctx context.Context, userID int) (bool, error) {
|
|
query := r.qb.Select("invites_issued", "invites_limit").
|
|
From("users").
|
|
Where(sq.Eq{"id": userID}).
|
|
Suffix("FOR UPDATE")
|
|
|
|
sqlQuery, args, err := query.ToSql()
|
|
if err != nil {
|
|
return false, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
|
}
|
|
|
|
var issued, limit int
|
|
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&issued, &limit)
|
|
if err != nil {
|
|
return false, errs.NewInternalError(errs.DatabaseError, "failed to check invite limit", err)
|
|
}
|
|
|
|
return issued < limit, nil
|
|
}
|