155 lines
3.9 KiB
Go
155 lines
3.9 KiB
Go
package service
|
|
|
|
import (
|
|
"context"
|
|
"math"
|
|
|
|
"github.com/google/uuid"
|
|
"smart-search-back/internal/ai"
|
|
"smart-search-back/internal/model"
|
|
"smart-search-back/internal/repository"
|
|
"smart-search-back/pkg/errors"
|
|
)
|
|
|
|
type requestService struct {
|
|
requestRepo repository.RequestRepository
|
|
supplierRepo repository.SupplierRepository
|
|
tokenUsageRepo repository.TokenUsageRepository
|
|
userRepo repository.UserRepository
|
|
openAI *ai.OpenAIClient
|
|
perplexity *ai.PerplexityClient
|
|
}
|
|
|
|
func NewRequestService(
|
|
requestRepo repository.RequestRepository,
|
|
supplierRepo repository.SupplierRepository,
|
|
tokenUsageRepo repository.TokenUsageRepository,
|
|
userRepo repository.UserRepository,
|
|
openAI *ai.OpenAIClient,
|
|
perplexity *ai.PerplexityClient,
|
|
) RequestService {
|
|
return &requestService{
|
|
requestRepo: requestRepo,
|
|
supplierRepo: supplierRepo,
|
|
tokenUsageRepo: tokenUsageRepo,
|
|
userRepo: userRepo,
|
|
openAI: openAI,
|
|
perplexity: perplexity,
|
|
}
|
|
}
|
|
|
|
func (s *requestService) CreateTZ(ctx context.Context, userID int, requestTxt string) (uuid.UUID, string, error) {
|
|
req := &model.Request{
|
|
UserID: userID,
|
|
RequestTxt: requestTxt,
|
|
}
|
|
|
|
if err := s.requestRepo.Create(ctx, req); err != nil {
|
|
return uuid.Nil, "", err
|
|
}
|
|
|
|
if requestTxt == "" {
|
|
return req.ID, "", nil
|
|
}
|
|
|
|
tzText, err := s.openAI.GenerateTZ(requestTxt)
|
|
if err != nil {
|
|
if err := s.requestRepo.UpdateWithTZ(ctx, req.ID, "", false); err != nil {
|
|
return req.ID, "", err
|
|
}
|
|
return req.ID, "", err
|
|
}
|
|
|
|
inputLen := len(requestTxt)
|
|
outputLen := len(tzText)
|
|
promptTokens := 500
|
|
|
|
inputTokens := int(math.Ceil(float64(inputLen) / 2.0))
|
|
outputTokens := int(math.Ceil(float64(outputLen) / 2.0))
|
|
|
|
totalTokens := inputTokens + outputTokens + promptTokens
|
|
tokenPrice := 25000.0 / 1000000.0
|
|
cost := float64(totalTokens) * tokenPrice
|
|
|
|
tokenUsage := &model.TokenUsage{
|
|
RequestID: req.ID,
|
|
RequestTokenCount: inputTokens + promptTokens,
|
|
ResponseTokenCount: outputTokens,
|
|
TokenCost: cost,
|
|
Type: "tz",
|
|
}
|
|
|
|
if err := s.tokenUsageRepo.Create(ctx, tokenUsage); err != nil {
|
|
return req.ID, "", err
|
|
}
|
|
|
|
if err := s.userRepo.UpdateBalance(ctx, userID, -cost); err != nil {
|
|
return req.ID, "", err
|
|
}
|
|
|
|
if err := s.requestRepo.UpdateWithTZ(ctx, req.ID, tzText, true); err != nil {
|
|
return req.ID, "", err
|
|
}
|
|
|
|
return req.ID, tzText, nil
|
|
}
|
|
|
|
func (s *requestService) ApproveTZ(ctx context.Context, requestID uuid.UUID, tzText string, userID int) ([]*model.Supplier, error) {
|
|
if err := s.requestRepo.UpdateFinalTZ(ctx, requestID, tzText); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
var suppliers []*model.Supplier
|
|
var promptTokens, responseTokens int
|
|
var err error
|
|
|
|
for attempt := 0; attempt < 3; attempt++ {
|
|
suppliers, promptTokens, responseTokens, err = s.perplexity.FindSuppliers(tzText)
|
|
if err == nil {
|
|
break
|
|
}
|
|
}
|
|
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if len(suppliers) == 0 {
|
|
return nil, errors.NewInternalError(errors.AIAPIError, "no suppliers found", nil)
|
|
}
|
|
|
|
if err := s.supplierRepo.BulkInsert(ctx, requestID, suppliers); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
tokenPrice := 25000.0 / 1000000.0
|
|
totalTokens := promptTokens + responseTokens
|
|
cost := float64(totalTokens) * tokenPrice
|
|
|
|
tokenUsage := &model.TokenUsage{
|
|
RequestID: requestID,
|
|
RequestTokenCount: promptTokens,
|
|
ResponseTokenCount: responseTokens,
|
|
TokenCost: cost,
|
|
Type: "suppliers",
|
|
}
|
|
|
|
if err := s.tokenUsageRepo.Create(ctx, tokenUsage); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
if err := s.userRepo.UpdateBalance(ctx, userID, -cost); err != nil {
|
|
return nil, err
|
|
}
|
|
|
|
return suppliers, nil
|
|
}
|
|
|
|
func (s *requestService) GetMailingList(ctx context.Context, userID int) ([]*model.Request, error) {
|
|
return s.requestRepo.GetByUserID(ctx, userID)
|
|
}
|
|
|
|
func (s *requestService) GetMailingListByID(ctx context.Context, requestID uuid.UUID) (*model.RequestDetail, error) {
|
|
return s.requestRepo.GetDetailByID(ctx, requestID)
|
|
}
|