176 lines
5.0 KiB
Go
176 lines
5.0 KiB
Go
package repository
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"time"
|
|
|
|
"git.techease.ru/Smart-search/smart-search-back/internal/model"
|
|
errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors"
|
|
|
|
sq "github.com/Masterminds/squirrel"
|
|
"github.com/jackc/pgx/v5"
|
|
"github.com/jackc/pgx/v5/pgxpool"
|
|
)
|
|
|
|
// sessionRepository is the PostgreSQL-backed implementation of
// SessionRepository returned by NewSessionRepository. The methods in this
// file read and write the "sessions" table.
type sessionRepository struct {
	// pool is the pgx connection pool every statement is executed on.
	pool *pgxpool.Pool
	// qb builds SQL statements; NewSessionRepository configures it to emit
	// PostgreSQL-style $1, $2, ... placeholders.
	qb sq.StatementBuilderType
}
|
|
|
|
func NewSessionRepository(pool *pgxpool.Pool) SessionRepository {
|
|
return &sessionRepository{
|
|
pool: pool,
|
|
qb: sq.StatementBuilder.PlaceholderFormat(sq.Dollar),
|
|
}
|
|
}
|
|
|
|
// Create inserts session into the sessions table and then writes the
// database-assigned fields back onto it: session.ID and session.CreatedAt.
//
// The caller supplies UserID, AccessToken, RefreshToken, IP, UserAgent and
// ExpiresAt. A failure to build or execute the statement is returned as an
// internal DatabaseError.
func (r *sessionRepository) Create(ctx context.Context, session *model.Session) error {
	query := r.qb.Insert("sessions").
		Columns("user_id", "access_token", "refresh_token", "ip", "user_agent", "expires_at").
		Values(
			session.UserID, session.AccessToken, session.RefreshToken,
			session.IP, session.UserAgent, session.ExpiresAt,
		).
		// created_at is not part of the insert, so its value comes from the
		// database (presumably a column default). Return it with id so the
		// in-memory session matches what FindByRefreshToken would read back.
		Suffix("RETURNING id, created_at")

	sqlQuery, args, err := query.ToSql()
	if err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
	}

	if err := r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&session.ID, &session.CreatedAt); err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to create session", err)
	}

	return nil
}
|
|
|
|
// FindByRefreshToken loads the active session that owns token. A session is
// active when revoked_at is NULL and expires_at is still later than the
// database's now().
//
// It returns a RefreshInvalid business error when no active session matches,
// and an internal DatabaseError for any other failure.
func (r *sessionRepository) FindByRefreshToken(ctx context.Context, token string) (*model.Session, error) {
	activeByToken := sq.And{
		sq.Eq{"refresh_token": token},
		sq.Expr("revoked_at IS NULL"),
		sq.Expr("expires_at > now()"),
	}

	sqlQuery, args, err := r.qb.
		Select(
			"id", "user_id", "access_token", "refresh_token", "ip",
			"user_agent", "created_at", "expires_at", "revoked_at",
		).
		From("sessions").
		Where(activeByToken).
		ToSql()
	if err != nil {
		return nil, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
	}

	found := new(model.Session)
	row := r.pool.QueryRow(ctx, sqlQuery, args...)
	scanErr := row.Scan(
		&found.ID, &found.UserID, &found.AccessToken, &found.RefreshToken,
		&found.IP, &found.UserAgent, &found.CreatedAt, &found.ExpiresAt,
		&found.RevokedAt,
	)

	switch {
	case scanErr == nil:
		return found, nil
	case errors.Is(scanErr, pgx.ErrNoRows):
		// No matching row means the token is unknown, revoked, or expired.
		return nil, errs.NewBusinessError(errs.RefreshInvalid, "refresh token is invalid or expired")
	default:
		return nil, errs.NewInternalError(errs.DatabaseError, "failed to find session", scanErr)
	}
}
|
|
|
|
// UpdateAccessToken stores newAccessToken on the active session identified by
// refreshToken.
//
// Only sessions that are not revoked and not yet expired (per the database's
// now()) are updated, which is the same filter FindByRefreshToken applies. If
// no such session exists, for example because it was revoked between lookup
// and update, a RefreshInvalid business error is returned. Success is not
// reported for an access token that was never stored. Build and execution
// failures are returned as an internal DatabaseError.
func (r *sessionRepository) UpdateAccessToken(ctx context.Context, refreshToken, newAccessToken string) error {
	query := r.qb.Update("sessions").
		Set("access_token", newAccessToken).
		Where(sq.And{
			sq.Eq{"refresh_token": refreshToken},
			sq.Expr("revoked_at IS NULL"),
			sq.Expr("expires_at > now()"),
		})

	sqlQuery, args, err := query.ToSql()
	if err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
	}

	result, err := r.pool.Exec(ctx, sqlQuery, args...)
	if err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to update access token", err)
	}

	// Zero rows means the refresh token is unknown, revoked, or expired.
	if result.RowsAffected() == 0 {
		return errs.NewBusinessError(errs.RefreshInvalid, "refresh token is invalid or expired")
	}

	return nil
}
|
|
|
|
// Revoke marks the session identified by refreshToken as revoked by setting
// revoked_at to the current time.
//
// Only sessions that are not already revoked are updated. Repeated calls
// therefore keep the original revocation time instead of pushing revoked_at
// forward, which would delay the 30-day purge in DeleteExpired. Revoking an
// unknown or already-revoked token is not an error. Build and execution
// failures are returned as an internal DatabaseError.
func (r *sessionRepository) Revoke(ctx context.Context, refreshToken string) error {
	query := r.qb.Update("sessions").
		// NOTE(review): this timestamp comes from the application clock,
		// while the read paths compare against the database's now(). Confirm
		// the clocks are in sync, or consider setting revoked_at in SQL.
		Set("revoked_at", time.Now()).
		Where(sq.And{
			sq.Eq{"refresh_token": refreshToken},
			sq.Expr("revoked_at IS NULL"),
		})

	sqlQuery, args, err := query.ToSql()
	if err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
	}

	if _, err := r.pool.Exec(ctx, sqlQuery, args...); err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to revoke session", err)
	}

	return nil
}
|
|
|
|
// RevokeByAccessToken marks the session holding accessToken as revoked by
// setting revoked_at to the current time.
//
// Like Revoke, it only touches sessions that are not already revoked, so
// repeated calls keep the original revocation time and do not delay the
// 30-day purge in DeleteExpired. An unknown or already-revoked token is not
// an error. Build and execution failures are returned as an internal
// DatabaseError.
func (r *sessionRepository) RevokeByAccessToken(ctx context.Context, accessToken string) error {
	query := r.qb.Update("sessions").
		// NOTE(review): application clock here versus database now() on the
		// read paths. Confirm the clocks are in sync.
		Set("revoked_at", time.Now()).
		Where(sq.And{
			sq.Eq{"access_token": accessToken},
			sq.Expr("revoked_at IS NULL"),
		})

	sqlQuery, args, err := query.ToSql()
	if err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
	}

	if _, err := r.pool.Exec(ctx, sqlQuery, args...); err != nil {
		return errs.NewInternalError(errs.DatabaseError, "failed to revoke session", err)
	}

	return nil
}
|
|
|
|
// IsAccessTokenValid reports whether accessToken belongs to an active
// session: one whose revoked_at is NULL and whose expires_at is later than
// the database's now(). Build and execution failures are returned as an
// internal DatabaseError together with false.
func (r *sessionRepository) IsAccessTokenValid(ctx context.Context, accessToken string) (bool, error) {
	// SELECT EXISTS (SELECT 1 ...) lets PostgreSQL stop at the first matching
	// row instead of counting every match just to compare the count with zero.
	query := r.qb.Select("1").
		Prefix("SELECT EXISTS (").
		From("sessions").
		Where(sq.And{
			sq.Eq{"access_token": accessToken},
			sq.Expr("revoked_at IS NULL"),
			sq.Expr("expires_at > now()"),
		}).
		Suffix(")")

	sqlQuery, args, err := query.ToSql()
	if err != nil {
		return false, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
	}

	var valid bool
	if err := r.pool.QueryRow(ctx, sqlQuery, args...).Scan(&valid); err != nil {
		return false, errs.NewInternalError(errs.DatabaseError, "failed to check token validity", err)
	}

	return valid, nil
}
|
|
|
|
// DeleteExpired removes every session that matches either of two rules:
//   - its expires_at is earlier than the database's now(), or
//   - it was revoked more than 30 days ago (by the database clock).
//
// It returns how many rows were deleted. Build and execution failures are
// returned as an internal DatabaseError together with 0.
func (r *sessionRepository) DeleteExpired(ctx context.Context) (int, error) {
	stale := sq.Or{
		sq.Expr("expires_at < now()"),
		sq.Expr("(revoked_at IS NOT NULL AND revoked_at < now() - interval '30 days')"),
	}

	sqlQuery, args, err := r.qb.Delete("sessions").Where(stale).ToSql()
	if err != nil {
		return 0, errs.NewInternalError(errs.DatabaseError, "failed to build query", err)
	}

	tag, execErr := r.pool.Exec(ctx, sqlQuery, args...)
	if execErr != nil {
		return 0, errs.NewInternalError(errs.DatabaseError, "failed to delete expired sessions", execErr)
	}

	deleted := tag.RowsAffected()
	return int(deleted), nil
}
|