From ff08bb225428ccfba5e9b9bbd5360d46ba01352c Mon Sep 17 00:00:00 2001 From: vallyenfail Date: Mon, 19 Jan 2026 19:14:51 +0300 Subject: [PATCH] add service --- api/proto/user/user.proto | 13 ++- internal/grpc/server.go | 2 +- internal/grpc/user_handler.go | 17 ++-- internal/model/supplier.go | 6 ++ internal/repository/interfaces.go | 1 + internal/repository/token_usage.go | 57 ++++++++++++ internal/service/interfaces.go | 1 + internal/service/user.go | 34 +++++-- pkg/pb/user/user.pb.go | 138 ++++++++++++++++++++--------- tests/repository_test.go | 58 ++++++++++++ tests/user_handler_test.go | 3 +- 11 files changed, 268 insertions(+), 62 deletions(-) diff --git a/api/proto/user/user.proto b/api/proto/user/user.proto index d884374..0db9045 100644 --- a/api/proto/user/user.proto +++ b/api/proto/user/user.proto @@ -44,8 +44,13 @@ message GetBalanceStatisticsRequest { int64 user_id = 1; } -message GetBalanceStatisticsResponse { - double balance = 1; - int32 total_requests = 2; - double total_spent = 3; +message WriteOffHistoryItem { + string operation_id = 1; + string data = 2; + double amount = 3; +} + +message GetBalanceStatisticsResponse { + double average_cost = 1; + repeated WriteOffHistoryItem write_off_history = 2; } diff --git a/internal/grpc/server.go b/internal/grpc/server.go index 3475cf9..200426e 100644 --- a/internal/grpc/server.go +++ b/internal/grpc/server.go @@ -59,7 +59,7 @@ func NewHandlers(pool *pgxpool.Pool, jwtSecret, cryptoSecret, openAIKey, perplex perplexityClient := ai.NewPerplexityClient(perplexityKey) authService := service.NewAuthService(userRepo, sessionRepo, inviteRepo, txManager, jwtSecret, cryptoSecret) - userService := service.NewUserService(userRepo, requestRepo, cryptoSecret) + userService := service.NewUserService(userRepo, requestRepo, tokenUsageRepo, cryptoSecret) inviteService := service.NewInviteService(inviteRepo, userRepo, txManager) requestService := service.NewRequestService(requestRepo, supplierRepo, tokenUsageRepo, userRepo, openAIClient, perplexityClient, txManager) supplierService := service.NewSupplierService(supplierRepo, requestRepo) diff --git a/internal/grpc/user_handler.go b/internal/grpc/user_handler.go index 81862ec..217b32b 100644 --- a/internal/grpc/user_handler.go +++ b/internal/grpc/user_handler.go @@ -48,19 +48,22 @@ func (h *UserHandler) GetStatistics(ctx context.Context, req *pb.GetStatisticsRe } func (h *UserHandler) GetBalanceStatistics(ctx context.Context, req *pb.GetBalanceStatisticsRequest) (*pb.GetBalanceStatisticsResponse, error) { - balance, err := h.userService.GetBalance(ctx, int(req.UserId)) + stats, err := h.userService.GetBalanceStatistics(ctx, int(req.UserId)) if err != nil { return nil, errors.ToGRPCError(err, h.logger, "UserService.GetBalanceStatistics") } - stats, err := h.userService.GetStatistics(ctx, int(req.UserId)) - if err != nil { - return nil, errors.ToGRPCError(err, h.logger, "UserService.GetBalanceStatistics") + history := make([]*pb.WriteOffHistoryItem, 0, len(stats.WriteOffHistory)) + for _, item := range stats.WriteOffHistory { + history = append(history, &pb.WriteOffHistoryItem{ + OperationId: item.OperationID, + Data: item.Data, + Amount: item.Amount, + }) } return &pb.GetBalanceStatisticsResponse{ - Balance: balance, - TotalRequests: int32(stats.RequestsCount), - TotalSpent: 0, + AverageCost: stats.AverageCost, + WriteOffHistory: history, }, nil } diff --git a/internal/model/supplier.go b/internal/model/supplier.go index 285c6a1..95d9413 100644 --- a/internal/model/supplier.go +++ b/internal/model/supplier.go @@ -26,3 +26,9 @@ type TokenUsage struct { Type string CreatedAt time.Time } + +type WriteOffHistory struct { + OperationID string + Data string + Amount float64 +} diff --git a/internal/repository/interfaces.go b/internal/repository/interfaces.go index 785d63c..42333c1 100644 --- a/internal/repository/interfaces.go +++ b/internal/repository/interfaces.go @@ -64,4 +64,5 @@ type SupplierRepository interface { type TokenUsageRepository interface { Create(ctx context.Context, usage *model.TokenUsage) error CreateTx(ctx context.Context, tx pgx.Tx, usage *model.TokenUsage) error + GetBalanceStatistics(ctx context.Context, userID int) (averageCost float64, history []*model.WriteOffHistory, err error) } diff --git a/internal/repository/token_usage.go b/internal/repository/token_usage.go index 30ec8e5..1d2340a 100644 --- a/internal/repository/token_usage.go +++ b/internal/repository/token_usage.go @@ -2,6 +2,7 @@ package repository import ( "context" + "strconv" "git.techease.ru/Smart-search/smart-search-back/internal/model" errs "git.techease.ru/Smart-search/smart-search-back/pkg/errors" @@ -49,3 +50,59 @@ func (r *tokenUsageRepository) createWithExecutor(ctx context.Context, exec DBTX return nil } + +func (r *tokenUsageRepository) GetBalanceStatistics(ctx context.Context, userID int) (float64, []*model.WriteOffHistory, error) { + avgQuery := r.qb.Select("COALESCE(AVG(COALESCE(rtu.token_cost, 0)), 0)"). + From("request_token_usage rtu"). + Join("requests_for_suppliers rfs ON rtu.request_id = rfs.id"). + Where(sq.Eq{"rfs.user_id": userID}) + + avgSQL, avgArgs, err := avgQuery.ToSql() + if err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to build average query", err) + } + + var averageCost float64 + if err := r.pool.QueryRow(ctx, avgSQL, avgArgs...).Scan(&averageCost); err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to get average cost", err) + } + + historyQuery := r.qb.Select( + "rtu.id", + "TO_CHAR(rtu.created_at, 'DD-MM-YYYY')", + "COALESCE(rtu.token_cost, 0)", + ). + From("request_token_usage rtu"). + Join("requests_for_suppliers rfs ON rtu.request_id = rfs.id"). + Where(sq.Eq{"rfs.user_id": userID}). + OrderBy("rtu.created_at DESC"). + Limit(8) + + historySQL, historyArgs, err := historyQuery.ToSql() + if err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to build history query", err) + } + + rows, err := r.pool.Query(ctx, historySQL, historyArgs...) + if err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to get write-off history", err) + } + defer rows.Close() + + var history []*model.WriteOffHistory + for rows.Next() { + var operationID int + var item model.WriteOffHistory + if err := rows.Scan(&operationID, &item.Data, &item.Amount); err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to scan write-off history", err) + } + item.OperationID = strconv.Itoa(operationID) + history = append(history, &item) + } + + if err := rows.Err(); err != nil { + return 0, nil, errs.NewInternalError(errs.DatabaseError, "failed to iterate write-off history", err) + } + + return averageCost, history, nil +} diff --git a/internal/service/interfaces.go b/internal/service/interfaces.go index e581dee..2a107ed 100644 --- a/internal/service/interfaces.go +++ b/internal/service/interfaces.go @@ -20,6 +20,7 @@ type UserService interface { GetInfo(ctx context.Context, userID int) (*UserInfo, error) GetBalance(ctx context.Context, userID int) (float64, error) GetStatistics(ctx context.Context, userID int) (*Statistics, error) + GetBalanceStatistics(ctx context.Context, userID int) (*BalanceStatistics, error) } type InviteService interface { diff --git a/internal/service/user.go b/internal/service/user.go index fcb19f0..9f2ccda 100644 --- a/internal/service/user.go +++ b/internal/service/user.go @@ -3,14 +3,16 @@ package service import ( "context" + "git.techease.ru/Smart-search/smart-search-back/internal/model" "git.techease.ru/Smart-search/smart-search-back/internal/repository" "git.techease.ru/Smart-search/smart-search-back/pkg/crypto" ) type userService struct { - userRepo repository.UserRepository - requestRepo repository.RequestRepository - cryptoHelper *crypto.Crypto + userRepo repository.UserRepository + requestRepo repository.RequestRepository + tokenUsageRepo repository.TokenUsageRepository + cryptoHelper *crypto.Crypto } type UserInfo struct { @@ -27,11 +29,17 @@ type Statistics struct { CreatedTZ int } -func NewUserService(userRepo repository.UserRepository, requestRepo repository.RequestRepository, cryptoSecret string) UserService { +type BalanceStatistics struct { + AverageCost float64 + WriteOffHistory []*model.WriteOffHistory +} + +func NewUserService(userRepo repository.UserRepository, requestRepo repository.RequestRepository, tokenUsageRepo repository.TokenUsageRepository, cryptoSecret string) UserService { return &userService{ - userRepo: userRepo, - requestRepo: requestRepo, - cryptoHelper: crypto.NewCrypto(cryptoSecret), + userRepo: userRepo, + requestRepo: requestRepo, + tokenUsageRepo: tokenUsageRepo, + cryptoHelper: crypto.NewCrypto(cryptoSecret), } } @@ -81,3 +89,15 @@ func (s *userService) GetStatistics(ctx context.Context, userID int) (*Statistic CreatedTZ: createdTZ, }, nil } + +func (s *userService) GetBalanceStatistics(ctx context.Context, userID int) (*BalanceStatistics, error) { + averageCost, history, err := s.tokenUsageRepo.GetBalanceStatistics(ctx, userID) + if err != nil { + return nil, err + } + + return &BalanceStatistics{ + AverageCost: averageCost, + WriteOffHistory: history, + }, nil +} diff --git a/pkg/pb/user/user.pb.go b/pkg/pb/user/user.pb.go index 863b9ba..7b7c764 100644 --- a/pkg/pb/user/user.pb.go +++ b/pkg/pb/user/user.pb.go @@ -385,18 +385,77 @@ func (x *GetBalanceStatisticsRequest) GetUserId() int64 { return 0 } -type GetBalanceStatisticsResponse struct { +type WriteOffHistoryItem struct { state protoimpl.MessageState `protogen:"open.v1"` - Balance float64 `protobuf:"fixed64,1,opt,name=balance,proto3" json:"balance,omitempty"` - TotalRequests int32 `protobuf:"varint,2,opt,name=total_requests,json=totalRequests,proto3" json:"total_requests,omitempty"` - TotalSpent float64 `protobuf:"fixed64,3,opt,name=total_spent,json=totalSpent,proto3" json:"total_spent,omitempty"` + OperationId string `protobuf:"bytes,1,opt,name=operation_id,json=operationId,proto3" json:"operation_id,omitempty"` + Data string `protobuf:"bytes,2,opt,name=data,proto3" json:"data,omitempty"` + Amount float64 `protobuf:"fixed64,3,opt,name=amount,proto3" json:"amount,omitempty"` unknownFields protoimpl.UnknownFields sizeCache protoimpl.SizeCache } +func (x *WriteOffHistoryItem) Reset() { + *x = WriteOffHistoryItem{} + mi := &file_user_user_proto_msgTypes[7] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *WriteOffHistoryItem) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*WriteOffHistoryItem) ProtoMessage() {} + +func (x *WriteOffHistoryItem) ProtoReflect() protoreflect.Message { + mi := &file_user_user_proto_msgTypes[7] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use WriteOffHistoryItem.ProtoReflect.Descriptor instead. +func (*WriteOffHistoryItem) Descriptor() ([]byte, []int) { + return file_user_user_proto_rawDescGZIP(), []int{7} +} + +func (x *WriteOffHistoryItem) GetOperationId() string { + if x != nil { + return x.OperationId + } + return "" +} + +func (x *WriteOffHistoryItem) GetData() string { + if x != nil { + return x.Data + } + return "" +} + +func (x *WriteOffHistoryItem) GetAmount() float64 { + if x != nil { + return x.Amount + } + return 0 +} + +type GetBalanceStatisticsResponse struct { + state protoimpl.MessageState `protogen:"open.v1"` + AverageCost float64 `protobuf:"fixed64,1,opt,name=average_cost,json=averageCost,proto3" json:"average_cost,omitempty"` + WriteOffHistory []*WriteOffHistoryItem `protobuf:"bytes,2,rep,name=write_off_history,json=writeOffHistory,proto3" json:"write_off_history,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + func (x *GetBalanceStatisticsResponse) Reset() { *x = GetBalanceStatisticsResponse{} - mi := &file_user_user_proto_msgTypes[7] + mi := &file_user_user_proto_msgTypes[8] ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) ms.StoreMessageInfo(mi) } @@ -408,7 +467,7 @@ func (x *GetBalanceStatisticsResponse) String() string { func (*GetBalanceStatisticsResponse) ProtoMessage() {} func (x *GetBalanceStatisticsResponse) ProtoReflect() protoreflect.Message { - mi := &file_user_user_proto_msgTypes[7] + mi := &file_user_user_proto_msgTypes[8] if x != nil { ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) if ms.LoadMessageInfo() == nil { @@ -421,28 +480,21 @@ func (x *GetBalanceStatisticsResponse) ProtoReflect() protoreflect.Message { // Deprecated: Use GetBalanceStatisticsResponse.ProtoReflect.Descriptor instead. func (*GetBalanceStatisticsResponse) Descriptor() ([]byte, []int) { - return file_user_user_proto_rawDescGZIP(), []int{7} + return file_user_user_proto_rawDescGZIP(), []int{8} } -func (x *GetBalanceStatisticsResponse) GetBalance() float64 { +func (x *GetBalanceStatisticsResponse) GetAverageCost() float64 { if x != nil { - return x.Balance + return x.AverageCost } return 0 } -func (x *GetBalanceStatisticsResponse) GetTotalRequests() int32 { +func (x *GetBalanceStatisticsResponse) GetWriteOffHistory() []*WriteOffHistoryItem { if x != nil { - return x.TotalRequests + return x.WriteOffHistory } - return 0 -} - -func (x *GetBalanceStatisticsResponse) GetTotalSpent() float64 { - if x != nil { - return x.TotalSpent - } - return 0 + return nil } var File_user_user_proto protoreflect.FileDescriptor @@ -471,12 +523,14 @@ const file_user_user_proto_rawDesc = "" + "\vtotal_spent\x18\x04 \x01(\x01R\n" + "totalSpent\"6\n" + "\x1bGetBalanceStatisticsRequest\x12\x17\n" + - "\auser_id\x18\x01 \x01(\x03R\x06userId\"\x80\x01\n" + - "\x1cGetBalanceStatisticsResponse\x12\x18\n" + - "\abalance\x18\x01 \x01(\x01R\abalance\x12%\n" + - "\x0etotal_requests\x18\x02 \x01(\x05R\rtotalRequests\x12\x1f\n" + - "\vtotal_spent\x18\x03 \x01(\x01R\n" + - "totalSpent2\xaf\x02\n" + + "\auser_id\x18\x01 \x01(\x03R\x06userId\"d\n" + + "\x13WriteOffHistoryItem\x12!\n" + + "\foperation_id\x18\x01 \x01(\tR\voperationId\x12\x12\n" + + "\x04data\x18\x02 \x01(\tR\x04data\x12\x16\n" + + "\x06amount\x18\x03 \x01(\x01R\x06amount\"\x88\x01\n" + + "\x1cGetBalanceStatisticsResponse\x12!\n" + + "\faverage_cost\x18\x01 \x01(\x01R\vaverageCost\x12E\n" + + "\x11write_off_history\x18\x02 \x03(\v2\x19.user.WriteOffHistoryItemR\x0fwriteOffHistory2\xaf\x02\n" + "\vUserService\x126\n" + "\aGetInfo\x12\x14.user.GetInfoRequest\x1a\x15.user.GetInfoResponse\x12?\n" + "\n" + @@ -496,7 +550,7 @@ func file_user_user_proto_rawDescGZIP() []byte { return file_user_user_proto_rawDescData } -var file_user_user_proto_msgTypes = make([]protoimpl.MessageInfo, 8) +var file_user_user_proto_msgTypes = make([]protoimpl.MessageInfo, 9) var file_user_user_proto_goTypes = []any{ (*GetInfoRequest)(nil), // 0: user.GetInfoRequest (*GetInfoResponse)(nil), // 1: user.GetInfoResponse @@ -505,22 +559,24 @@ var file_user_user_proto_goTypes = []any{ (*GetStatisticsRequest)(nil), // 4: user.GetStatisticsRequest (*GetStatisticsResponse)(nil), // 5: user.GetStatisticsResponse (*GetBalanceStatisticsRequest)(nil), // 6: user.GetBalanceStatisticsRequest - (*GetBalanceStatisticsResponse)(nil), // 7: user.GetBalanceStatisticsResponse + (*WriteOffHistoryItem)(nil), // 7: user.WriteOffHistoryItem + (*GetBalanceStatisticsResponse)(nil), // 8: user.GetBalanceStatisticsResponse } var file_user_user_proto_depIdxs = []int32{ - 0, // 0: user.UserService.GetInfo:input_type -> user.GetInfoRequest - 2, // 1: user.UserService.GetBalance:input_type -> user.GetBalanceRequest - 4, // 2: user.UserService.GetStatistics:input_type -> user.GetStatisticsRequest - 6, // 3: user.UserService.GetBalanceStatistics:input_type -> user.GetBalanceStatisticsRequest - 1, // 4: user.UserService.GetInfo:output_type -> user.GetInfoResponse - 3, // 5: user.UserService.GetBalance:output_type -> user.GetBalanceResponse - 5, // 6: user.UserService.GetStatistics:output_type -> user.GetStatisticsResponse - 7, // 7: user.UserService.GetBalanceStatistics:output_type -> user.GetBalanceStatisticsResponse - 4, // [4:8] is the sub-list for method output_type - 0, // [0:4] is the sub-list for method input_type - 0, // [0:0] is the sub-list for extension type_name - 0, // [0:0] is the sub-list for extension extendee - 0, // [0:0] is the sub-list for field type_name + 7, // 0: user.GetBalanceStatisticsResponse.write_off_history:type_name -> user.WriteOffHistoryItem + 0, // 1: user.UserService.GetInfo:input_type -> user.GetInfoRequest + 2, // 2: user.UserService.GetBalance:input_type -> user.GetBalanceRequest + 4, // 3: user.UserService.GetStatistics:input_type -> user.GetStatisticsRequest + 6, // 4: user.UserService.GetBalanceStatistics:input_type -> user.GetBalanceStatisticsRequest + 1, // 5: user.UserService.GetInfo:output_type -> user.GetInfoResponse + 3, // 6: user.UserService.GetBalance:output_type -> user.GetBalanceResponse + 5, // 7: user.UserService.GetStatistics:output_type -> user.GetStatisticsResponse + 8, // 8: user.UserService.GetBalanceStatistics:output_type -> user.GetBalanceStatisticsResponse + 5, // [5:9] is the sub-list for method output_type + 1, // [1:5] is the sub-list for method input_type + 1, // [1:1] is the sub-list for extension type_name + 1, // [1:1] is the sub-list for extension extendee + 0, // [0:1] is the sub-list for field type_name } func init() { file_user_user_proto_init() } @@ -534,7 +590,7 @@ func file_user_user_proto_init() { GoPackagePath: reflect.TypeOf(x{}).PkgPath(), RawDescriptor: unsafe.Slice(unsafe.StringData(file_user_user_proto_rawDesc), len(file_user_user_proto_rawDesc)), NumEnums: 0, - NumMessages: 8, + NumMessages: 9, NumExtensions: 0, NumServices: 1, }, diff --git a/tests/repository_test.go b/tests/repository_test.go index 8adfd38..97f1ccf 100644 --- a/tests/repository_test.go +++ b/tests/repository_test.go @@ -233,6 +233,64 @@ func (s *IntegrationSuite) TestRepository_TokenUsageCreate() { s.NoError(err) } +func (s *IntegrationSuite) TestRepository_TokenUsageGetBalanceStatistics() { + tokenRepo := repository.NewTokenUsageRepository(s.pool) + requestRepo := repository.NewRequestRepository(s.pool) + ctx := context.Background() + + var userID int + err := s.pool.QueryRow(ctx, "SELECT id FROM users LIMIT 1").Scan(&userID) + s.Require().NoError(err) + + req := &model.Request{ + UserID: userID, + RequestTxt: "Test request for balance statistics", + } + err = requestRepo.Create(ctx, req) + s.Require().NoError(err) + + usage1 := &model.TokenUsage{ + RequestID: req.ID, + RequestTokenCount: 100, + ResponseTokenCount: 50, + TokenCost: 10.0, + Type: "openai", + } + err = tokenRepo.Create(ctx, usage1) + s.Require().NoError(err) + + usage2 := &model.TokenUsage{ + RequestID: req.ID, + RequestTokenCount: 200, + ResponseTokenCount: 100, + TokenCost: 20.0, + Type: "perplexity", + } + err = tokenRepo.Create(ctx, usage2) + s.Require().NoError(err) + + averageCost, history, err := tokenRepo.GetBalanceStatistics(ctx, userID) + s.NoError(err) + s.Greater(averageCost, 0.0) + s.GreaterOrEqual(len(history), 2) + + for _, item := range history { + s.NotEmpty(item.OperationID) + s.NotEmpty(item.Data) + s.GreaterOrEqual(item.Amount, 0.0) + } +} + +func (s *IntegrationSuite) TestRepository_TokenUsageGetBalanceStatisticsEmpty() { + tokenRepo := repository.NewTokenUsageRepository(s.pool) + ctx := context.Background() + + averageCost, history, err := tokenRepo.GetBalanceStatistics(ctx, 999999) + s.NoError(err) + s.Equal(0.0, averageCost) + s.Empty(history) +} + func (s *IntegrationSuite) TestRepository_UserCreate() { userRepo := repository.NewUserRepository(s.pool, testCryptoSecret) ctx := context.Background() diff --git a/tests/user_handler_test.go b/tests/user_handler_test.go index dbf90d9..4dae08f 100644 --- a/tests/user_handler_test.go +++ b/tests/user_handler_test.go @@ -72,8 +72,7 @@ func (s *IntegrationSuite) TestUserHandler_GetBalanceStatistics() { } s.NotNil(resp) - s.GreaterOrEqual(resp.Balance, 0.0) - s.GreaterOrEqual(resp.TotalRequests, int32(0)) + s.GreaterOrEqual(resp.AverageCost, 0.0) } func (s *IntegrationSuite) TestUserHandler_GetInfoWithValidUser() {