This commit is contained in:
@@ -3,8 +3,9 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
||||
"github.com/google/uuid"
|
||||
"smart-search-back/internal/model"
|
||||
"github.com/jackc/pgx/v5"
|
||||
)
|
||||
|
||||
type UserRepository interface {
|
||||
@@ -12,9 +13,12 @@ type UserRepository interface {
|
||||
FindByID(ctx context.Context, userID int) (*model.User, error)
|
||||
Create(ctx context.Context, user *model.User) error
|
||||
UpdateBalance(ctx context.Context, userID int, delta float64) error
|
||||
UpdateBalanceTx(ctx context.Context, tx pgx.Tx, userID int, delta float64) error
|
||||
GetBalance(ctx context.Context, userID int) (float64, error)
|
||||
IncrementInvitesIssued(ctx context.Context, userID int) error
|
||||
IncrementInvitesIssuedTx(ctx context.Context, tx pgx.Tx, userID int) error
|
||||
CheckInviteLimit(ctx context.Context, userID int) (bool, error)
|
||||
CheckInviteLimitTx(ctx context.Context, tx pgx.Tx, userID int) (bool, error)
|
||||
}
|
||||
|
||||
type SessionRepository interface {
|
||||
@@ -22,11 +26,14 @@ type SessionRepository interface {
|
||||
FindByRefreshToken(ctx context.Context, token string) (*model.Session, error)
|
||||
UpdateAccessToken(ctx context.Context, refreshToken, newAccessToken string) error
|
||||
Revoke(ctx context.Context, refreshToken string) error
|
||||
RevokeByAccessToken(ctx context.Context, accessToken string) error
|
||||
IsAccessTokenValid(ctx context.Context, accessToken string) (bool, error)
|
||||
DeleteExpired(ctx context.Context) (int, error)
|
||||
}
|
||||
|
||||
type InviteRepository interface {
|
||||
Create(ctx context.Context, invite *model.InviteCode) error
|
||||
CreateTx(ctx context.Context, tx pgx.Tx, invite *model.InviteCode) error
|
||||
FindByCode(ctx context.Context, code int64) (*model.InviteCode, error)
|
||||
IncrementUsedCount(ctx context.Context, code int64) error
|
||||
DeactivateExpired(ctx context.Context) (int, error)
|
||||
@@ -36,6 +43,7 @@ type InviteRepository interface {
|
||||
type RequestRepository interface {
|
||||
Create(ctx context.Context, req *model.Request) error
|
||||
UpdateWithTZ(ctx context.Context, id uuid.UUID, tz string, generated bool) error
|
||||
UpdateWithTZTx(ctx context.Context, tx pgx.Tx, id uuid.UUID, tz string, generated bool) error
|
||||
UpdateFinalTZ(ctx context.Context, id uuid.UUID, finalTZ string) error
|
||||
GetByUserID(ctx context.Context, userID int) ([]*model.Request, error)
|
||||
GetByID(ctx context.Context, id uuid.UUID) (*model.Request, error)
|
||||
@@ -45,10 +53,12 @@ type RequestRepository interface {
|
||||
|
||||
type SupplierRepository interface {
|
||||
BulkInsert(ctx context.Context, requestID uuid.UUID, suppliers []*model.Supplier) error
|
||||
BulkInsertTx(ctx context.Context, tx pgx.Tx, requestID uuid.UUID, suppliers []*model.Supplier) error
|
||||
GetByRequestID(ctx context.Context, requestID uuid.UUID) ([]*model.Supplier, error)
|
||||
DeleteByRequestID(ctx context.Context, requestID uuid.UUID) error
|
||||
}
|
||||
|
||||
type TokenUsageRepository interface {
|
||||
Create(ctx context.Context, usage *model.TokenUsage) error
|
||||
CreateTx(ctx context.Context, tx pgx.Tx, usage *model.TokenUsage) error
|
||||
}
|
||||
|
||||
@@ -4,9 +4,8 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"smart-search-back/internal/model"
|
||||
errs "smart-search-back/pkg/errors"
|
||||
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
||||
errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
@@ -25,6 +24,14 @@ func NewInviteRepository(pool *pgxpool.Pool) InviteRepository {
|
||||
}
|
||||
|
||||
func (r *inviteRepository) Create(ctx context.Context, invite *model.InviteCode) error {
|
||||
return r.createWithExecutor(ctx, r.pool, invite)
|
||||
}
|
||||
|
||||
func (r *inviteRepository) CreateTx(ctx context.Context, tx pgx.Tx, invite *model.InviteCode) error {
|
||||
return r.createWithExecutor(ctx, tx, invite)
|
||||
}
|
||||
|
||||
func (r *inviteRepository) createWithExecutor(ctx context.Context, exec DBTX, invite *model.InviteCode) error {
|
||||
query := r.qb.Insert("invite_codes").Columns(
|
||||
"user_id", "code", "can_be_used_count", "expires_at",
|
||||
).Values(
|
||||
@@ -36,7 +43,7 @@ func (r *inviteRepository) Create(ctx context.Context, invite *model.InviteCode)
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&invite.ID, &invite.CreatedAt)
|
||||
err = exec.QueryRow(ctx, sqlQuery, args...).Scan(&invite.ID, &invite.CreatedAt)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to create invite code", err)
|
||||
}
|
||||
|
||||
@@ -5,9 +5,8 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
|
||||
"smart-search-back/internal/model"
|
||||
errs "smart-search-back/pkg/errors"
|
||||
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
||||
errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
@@ -47,6 +46,14 @@ func (r *requestRepository) Create(ctx context.Context, req *model.Request) erro
|
||||
}
|
||||
|
||||
func (r *requestRepository) UpdateWithTZ(ctx context.Context, id uuid.UUID, tz string, generated bool) error {
|
||||
return r.updateWithTZExecutor(ctx, r.pool, id, tz, generated)
|
||||
}
|
||||
|
||||
func (r *requestRepository) UpdateWithTZTx(ctx context.Context, tx pgx.Tx, id uuid.UUID, tz string, generated bool) error {
|
||||
return r.updateWithTZExecutor(ctx, tx, id, tz, generated)
|
||||
}
|
||||
|
||||
func (r *requestRepository) updateWithTZExecutor(ctx context.Context, exec DBTX, id uuid.UUID, tz string, generated bool) error {
|
||||
query := r.qb.Update("requests_for_suppliers").
|
||||
Set("final_tz", tz).
|
||||
Set("generated_tz", generated).
|
||||
@@ -58,7 +65,7 @@ func (r *requestRepository) UpdateWithTZ(ctx context.Context, id uuid.UUID, tz s
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, sqlQuery, args...)
|
||||
_, err = exec.Exec(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to update request", err)
|
||||
}
|
||||
|
||||
@@ -5,8 +5,8 @@ import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"smart-search-back/internal/model"
|
||||
errs "smart-search-back/pkg/errors"
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
||||
errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/jackc/pgx/v5"
|
||||
@@ -114,6 +114,47 @@ func (r *sessionRepository) Revoke(ctx context.Context, refreshToken string) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *sessionRepository) RevokeByAccessToken(ctx context.Context, accessToken string) error {
|
||||
query := r.qb.Update("sessions").
|
||||
Set("revoked_at", time.Now()).
|
||||
Where(sq.Eq{"access_token": accessToken})
|
||||
|
||||
sqlQuery, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to revoke session", err)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *sessionRepository) IsAccessTokenValid(ctx context.Context, accessToken string) (bool, error) {
|
||||
query := r.qb.Select("COUNT(*)").
|
||||
From("sessions").
|
||||
Where(sq.And{
|
||||
sq.Eq{"access_token": accessToken},
|
||||
sq.Expr("revoked_at IS NULL"),
|
||||
sq.Expr("expires_at > now()"),
|
||||
})
|
||||
|
||||
sqlQuery, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return false, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
var count int
|
||||
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&count)
|
||||
if err != nil {
|
||||
return false, errs.NewInternalError(errs.DatabaseError, "failed to check token validity", err)
|
||||
}
|
||||
|
||||
return count > 0, nil
|
||||
}
|
||||
|
||||
func (r *sessionRepository) DeleteExpired(ctx context.Context) (int, error) {
|
||||
query := r.qb.Delete("sessions").Where(sq.Or{
|
||||
sq.Expr("expires_at < now()"),
|
||||
|
||||
@@ -3,11 +3,11 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
|
||||
"smart-search-back/internal/model"
|
||||
errs "smart-search-back/pkg/errors"
|
||||
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
||||
errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/google/uuid"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
@@ -24,6 +24,14 @@ func NewSupplierRepository(pool *pgxpool.Pool) SupplierRepository {
|
||||
}
|
||||
|
||||
func (r *supplierRepository) BulkInsert(ctx context.Context, requestID uuid.UUID, suppliers []*model.Supplier) error {
|
||||
return r.bulkInsertWithExecutor(ctx, r.pool, requestID, suppliers)
|
||||
}
|
||||
|
||||
func (r *supplierRepository) BulkInsertTx(ctx context.Context, tx pgx.Tx, requestID uuid.UUID, suppliers []*model.Supplier) error {
|
||||
return r.bulkInsertWithExecutor(ctx, tx, requestID, suppliers)
|
||||
}
|
||||
|
||||
func (r *supplierRepository) bulkInsertWithExecutor(ctx context.Context, exec DBTX, requestID uuid.UUID, suppliers []*model.Supplier) error {
|
||||
if len(suppliers) == 0 {
|
||||
return nil
|
||||
}
|
||||
@@ -41,7 +49,7 @@ func (r *supplierRepository) BulkInsert(ctx context.Context, requestID uuid.UUID
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, sqlQuery, args...)
|
||||
_, err = exec.Exec(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to bulk insert suppliers", err)
|
||||
}
|
||||
|
||||
@@ -3,10 +3,10 @@ package repository
|
||||
import (
|
||||
"context"
|
||||
|
||||
"smart-search-back/internal/model"
|
||||
errs "smart-search-back/pkg/errors"
|
||||
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
||||
errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
@@ -23,6 +23,14 @@ func NewTokenUsageRepository(pool *pgxpool.Pool) TokenUsageRepository {
|
||||
}
|
||||
|
||||
func (r *tokenUsageRepository) Create(ctx context.Context, usage *model.TokenUsage) error {
|
||||
return r.createWithExecutor(ctx, r.pool, usage)
|
||||
}
|
||||
|
||||
func (r *tokenUsageRepository) CreateTx(ctx context.Context, tx pgx.Tx, usage *model.TokenUsage) error {
|
||||
return r.createWithExecutor(ctx, tx, usage)
|
||||
}
|
||||
|
||||
func (r *tokenUsageRepository) createWithExecutor(ctx context.Context, exec DBTX, usage *model.TokenUsage) error {
|
||||
query := r.qb.Insert("request_token_usage").Columns(
|
||||
"request_id", "request_token_count", "response_token_count", "token_cost", "type",
|
||||
).Values(
|
||||
@@ -34,7 +42,7 @@ func (r *tokenUsageRepository) Create(ctx context.Context, usage *model.TokenUsa
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&usage.ID, &usage.CreatedAt)
|
||||
err = exec.QueryRow(ctx, sqlQuery, args...).Scan(&usage.ID, &usage.CreatedAt)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to create token usage", err)
|
||||
}
|
||||
|
||||
48
internal/repository/tx.go
Normal file
48
internal/repository/tx.go
Normal file
@@ -0,0 +1,48 @@
|
||||
package repository
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/jackc/pgx/v5"
|
||||
"github.com/jackc/pgx/v5/pgconn"
|
||||
"github.com/jackc/pgx/v5/pgxpool"
|
||||
)
|
||||
|
||||
type DBTX interface {
|
||||
Exec(ctx context.Context, sql string, arguments ...any) (pgconn.CommandTag, error)
|
||||
Query(ctx context.Context, sql string, args ...any) (pgx.Rows, error)
|
||||
QueryRow(ctx context.Context, sql string, args ...any) pgx.Row
|
||||
}
|
||||
|
||||
type TxManager struct {
|
||||
pool *pgxpool.Pool
|
||||
}
|
||||
|
||||
func NewTxManager(pool *pgxpool.Pool) *TxManager {
|
||||
return &TxManager{pool: pool}
|
||||
}
|
||||
|
||||
func (tm *TxManager) WithTx(ctx context.Context, fn func(tx pgx.Tx) error) error {
|
||||
tx, err := tm.pool.Begin(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
defer func() {
|
||||
if p := recover(); p != nil {
|
||||
_ = tx.Rollback(ctx)
|
||||
panic(p)
|
||||
}
|
||||
}()
|
||||
|
||||
if err := fn(tx); err != nil {
|
||||
_ = tx.Rollback(ctx)
|
||||
return err
|
||||
}
|
||||
|
||||
return tx.Commit(ctx)
|
||||
}
|
||||
|
||||
func (tm *TxManager) Pool() *pgxpool.Pool {
|
||||
return tm.pool
|
||||
}
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"smart-search-back/internal/model"
|
||||
"smart-search-back/pkg/crypto"
|
||||
errs "smart-search-back/pkg/errors"
|
||||
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
||||
"git.techease.ru/Smart-search/smart-search-back/pkg/crypto"
|
||||
errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
||||
|
||||
sq "github.com/Masterminds/squirrel"
|
||||
"github.com/jackc/pgx/v5"
|
||||
@@ -123,20 +123,35 @@ func (r *userRepository) Create(ctx context.Context, user *model.User) error {
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalance(ctx context.Context, userID int, delta float64) error {
|
||||
return r.updateBalanceWithExecutor(ctx, r.pool, userID, delta)
|
||||
}
|
||||
|
||||
func (r *userRepository) UpdateBalanceTx(ctx context.Context, tx pgx.Tx, userID int, delta float64) error {
|
||||
return r.updateBalanceWithExecutor(ctx, tx, userID, delta)
|
||||
}
|
||||
|
||||
func (r *userRepository) updateBalanceWithExecutor(ctx context.Context, exec DBTX, userID int, delta float64) error {
|
||||
query := r.qb.Update("users").
|
||||
Set("balance", sq.Expr("balance + ?", delta)).
|
||||
Where(sq.Eq{"id": userID})
|
||||
Where(sq.And{
|
||||
sq.Eq{"id": userID},
|
||||
sq.Expr("balance + ? >= 0", delta),
|
||||
})
|
||||
|
||||
sqlQuery, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, sqlQuery, args...)
|
||||
result, err := exec.Exec(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to update balance", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return errs.NewBusinessError(errs.InsufficientBalance, "insufficient balance")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
@@ -161,24 +176,47 @@ func (r *userRepository) GetBalance(ctx context.Context, userID int) (float64, e
|
||||
}
|
||||
|
||||
func (r *userRepository) IncrementInvitesIssued(ctx context.Context, userID int) error {
|
||||
return r.incrementInvitesIssuedWithExecutor(ctx, r.pool, userID)
|
||||
}
|
||||
|
||||
func (r *userRepository) IncrementInvitesIssuedTx(ctx context.Context, tx pgx.Tx, userID int) error {
|
||||
return r.incrementInvitesIssuedWithExecutor(ctx, tx, userID)
|
||||
}
|
||||
|
||||
func (r *userRepository) incrementInvitesIssuedWithExecutor(ctx context.Context, exec DBTX, userID int) error {
|
||||
query := r.qb.Update("users").
|
||||
Set("invites_issued", sq.Expr("invites_issued + 1")).
|
||||
Where(sq.Eq{"id": userID})
|
||||
Where(sq.And{
|
||||
sq.Eq{"id": userID},
|
||||
sq.Expr("invites_issued < invites_limit"),
|
||||
})
|
||||
|
||||
sqlQuery, args, err := query.ToSql()
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
|
||||
}
|
||||
|
||||
_, err = r.pool.Exec(ctx, sqlQuery, args...)
|
||||
result, err := exec.Exec(ctx, sqlQuery, args...)
|
||||
if err != nil {
|
||||
return errs.NewInternalError(errs.DatabaseError, "failed to increment invites issued", err)
|
||||
}
|
||||
|
||||
if result.RowsAffected() == 0 {
|
||||
return errs.NewBusinessError(errs.InviteLimitReached, "invite limit reached")
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (r *userRepository) CheckInviteLimit(ctx context.Context, userID int) (bool, error) {
|
||||
return r.checkInviteLimitWithExecutor(ctx, r.pool, userID)
|
||||
}
|
||||
|
||||
func (r *userRepository) CheckInviteLimitTx(ctx context.Context, tx pgx.Tx, userID int) (bool, error) {
|
||||
return r.checkInviteLimitWithExecutor(ctx, tx, userID)
|
||||
}
|
||||
|
||||
func (r *userRepository) checkInviteLimitWithExecutor(ctx context.Context, exec DBTX, userID int) (bool, error) {
|
||||
query := r.qb.Select("invites_issued", "invites_limit").
|
||||
From("users").
|
||||
Where(sq.Eq{"id": userID}).
|
||||
@@ -190,7 +228,7 @@ func (r *userRepository) CheckInviteLimit(ctx context.Context, userID int) (bool
|
||||
}
|
||||
|
||||
var issued, limit int
|
||||
err = r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&issued, &limit)
|
||||
err = exec.QueryRow(ctx, sqlQuery, args...).Scan(&issued, &limit)
|
||||
if err != nil {
|
||||
return false, errs.NewInternalError(errs.DatabaseError, "failed to check invite limit", err)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user