From 30138bab38927b076a1614bded0930181eecabb8 Mon Sep 17 00:00:00 2001 From: xds Date: Fri, 13 Feb 2026 11:18:11 +0300 Subject: [PATCH] feat: Introduce generation grouping, enabling multiple generations per request via a new `count` parameter and retrieval by group ID. --- .../generation_router.cpython-313.pyc | Bin 10511 -> 11442 bytes api/endpoints/generation_router.py | 37 ++++++++++++------ api/models/GenerationRequest.py | 8 +++- .../GenerationRequest.cpython-313.pyc | Bin 3258 -> 3698 bytes .../generation_service.cpython-313.pyc | Bin 26145 -> 27060 bytes api/service/generation_service.py | 21 ++++++++-- models/Generation.py | 1 + models/__pycache__/Generation.cpython-313.pyc | Bin 3268 -> 3329 bytes .../generation_repo.cpython-313.pyc | Bin 5472 -> 6234 bytes repos/generation_repo.py | 8 ++++ 10 files changed, 58 insertions(+), 17 deletions(-) diff --git a/api/endpoints/__pycache__/generation_router.cpython-313.pyc b/api/endpoints/__pycache__/generation_router.cpython-313.pyc index fb99eafe262ca3d83034cdbbbc2883a81c565f81..f9e6ebaa2cb1c7ea7f1fe8d856f873845621c617 100644 GIT binary patch delta 2790 zcmZ`)O-vg{6y8~{jlnql7>5lISZsrF2!;d`!cS-jX$&}!$U~~6Z0fjpAx^=zW=++G zRFL)%sj1R7Q>mBUdT3gy>W?ouRi)~wQdJ{T(X4t%)l+X&q^Z(V=k4M(pkyU~`{wQY z`QCdoUZS)^%P zB9bmUBF*C-;?d>uh?jU}Z*&wQ?S4F*+|q*#9cg9poDu(c8)=g`S@3XpAFk*`xU!EY z0XY!8VEWP?UI7t%fhZk?TX4;c4$w;?&m~oe?IR*6>Kbr<&Hv!LN?q1u`74XBc!+Xun>qq`Nf&= zxU=ZUO05mqan~MVJKPn=^cPTK60Dd3>&iDt~D zDd1zz9@dVc8Fy(%)xCQX_RWaKx)S45^R#0^d(**3q4VU}Y57Nfc5ld0E~L~Q12oWy6c28954=mo+hL&gBY0OJ5~g^?4~>++tN1W^Jo znxKCH7y_7Jpq4P$)r^WvTFu7h%qbE4>3c4e%kA$$M^A#=9@4x{+GvZ-piyKsPU}S9 zkOXeeVsI8dJd*~2>raPE^oIClcU9*c?3T_Y{l(_+mS!)AGZtoBcuJY@4*PA{=%dPZ zns2VE77E|YH*GF)(_fr!`cLygpNo4*Q29JMp7+$nuzK{2Dc5d>R%1bO*N|@z(Ewd;fOkO`Ur|fKLOlL05 z#OLRggfyEbQlXT7DZp#4V)G44R<7pImCO3UQ9m#4Kt{lr<|QU+;_gIh?A14o5TqJD zmo20d%6u%bxR4>WOcAh=mmolKHhL5b%cVIP#K`vPR8^u9#NdY5zb5u?h0HVRQF3fl2sCWI&%Z8 z0KTRVJqqflFL*b2QcGPvUqEDPtxDgYYA{^9eGY)WRL#Z-i^zIaN3Zxib&y)Cj%SiV zC6&melc}s4)H_&DKk4%w zj&_A&ljmY1=O(9KoSf2m>jT_krkrsG?1PN(;=D3MK4Lo7Cr%Br%|L@Kz7m=Lm#O63 zrIv4yZE&SNk514^f6^UAk4v3s0c}}0erS^iannQp^k=ce@%CGFH|z4{P5G|yN?jg} z{t=$cpHtSuvkyZ9`N6a6p^3k(HYB5au1yZ!zjdG9w1Dua;8ttLYv8 O6cXvhwwyp2yZ-^s=Uy2A delta 1791 zcmZ`)No*Tc7@jvCPaG#s92{q{lX_<2-3eK2Ck;uLG;Km{a=?$Z9crca5V^#0^k#w( zs3<2cfKYh1sy74(1Zs<~95{0*Jyb%n5X%jT0}%lv)C=!_6Whh5Bh5Ge{M+|0Z|tvL z{-Drzmk)popugSx0sI+-6H(~<7v z(tmS}cjTI4xnAM5q(ueA=&xc#?ZQeckE%~&jor;G6JVh?#GrGgSp|yHv(hM!ot?Jp z=3bmVtovCP!{UrJA#b;4aZdKf7sdl zraEHN&C2N+X(;Pry9wZFJhLdW4K$}e%PO#$=Q^S@Ec!K=m@}cPg>8^7q!&fKM-%jc z<}Uf$V>-&nIMb1HVXHThrI(zykdwOQ%k*vec{EbLFGocbp$}X$-V<=9LI4VYN|(B8 zauUQOz1{sYiqsF>`{;=<(*T3$)FDg$#>=O8N?U@ECR?1|Vy|wu0?f#X}A~>}?kNH3c zuH+Q;dQ&~XLvkGIYu>-bi8COd1y}(%#~^{o1$cOlo1sPDXgCg5-po2kPXnw0Tm*pc zm@LzqzKNL)5FZ*Bw*Z_1c!q)5#h}$V{oUtBg^!$rsJ;=HL~5&`H$5w*l3|)k+1RNV zWxF()MCpxS>?EfK)`UbbBvt@a1`WrSv13Y10t!A@@&XtEUZVd5->N(PSv1$;d8l!b zH*Swsw%coWdy#Grg-<`WvpjG-+8$^xHK9cB4t-g-#;!}~9Q{N27JWv)QOD60`lA}{ zDT0@Md}zP6j@IalT6n49uGY%NcCoxy-DMkT)9c`V72pa0k2aXpkWHgR42x9gEiD*@ zb(kIoFpbt_kEIHM0I$~%wO^3D%EXd`di{g+w{YV!KMPX(L^lRlA%?9xC*c)N^Q%{*<&(m_A6=rFU7Te*wM}Z&UyP diff --git a/api/endpoints/generation_router.py b/api/endpoints/generation_router.py index 85c4f61..79f68bd 100644 --- a/api/endpoints/generation_router.py +++ b/api/endpoints/generation_router.py @@ -8,7 +8,7 @@ from api import service from api.dependency import get_generation_service, get_project_id, get_dao from repos.dao import DAO -from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest +from api.models.GenerationRequest import GenerationResponse, GenerationRequest, GenerationsResponse, PromptResponse, PromptRequest, GenerationGroupResponse from api.service.generation_service import GenerationService from models.Generation import Generation @@ -68,12 +68,12 @@ async def get_generations(character_id: Optional[str] = None, limit: int = 10, o return await generation_service.get_generations(character_id, limit=limit, offset=offset, user_id=user_id_filter, project_id=project_id) -@router.post("/_run", response_model=GenerationResponse) +@router.post("/_run", response_model=GenerationGroupResponse) async def post_generation(generation: GenerationRequest, request: Request, generation_service: GenerationService = Depends(get_generation_service), current_user: dict = Depends(get_current_user), project_id: Optional[str] = Depends(get_project_id), - dao: DAO = Depends(get_dao)) -> GenerationResponse: + dao: DAO = Depends(get_dao)) -> GenerationGroupResponse: logger.info(f"post_generation (run) called. LinkedCharId: {generation.linked_character_id}, PromptLength: {len(generation.prompt)}") if project_id: @@ -85,16 +85,6 @@ async def post_generation(generation: GenerationRequest, request: Request, return await generation_service.create_generation_task(generation, user_id=str(current_user.get("_id"))) -@router.get("/{generation_id}", response_model=GenerationResponse) -async def get_generation(generation_id: str, - generation_service: GenerationService = Depends(get_generation_service), - current_user: dict = Depends(get_current_user)) -> GenerationResponse: - logger.debug(f"get_generation called for ID: {generation_id}") - gen = await generation_service.get_generation(generation_id) - if gen and gen.created_by != str(current_user["_id"]): - raise HTTPException(status_code=403, detail="Access denied") - return gen - @router.get("/running") async def get_running_generations(request: Request, @@ -113,6 +103,27 @@ async def get_running_generations(request: Request, return await generation_service.get_running_generations(user_id=user_id_filter, project_id=project_id) +@router.get("/group/{group_id}", response_model=GenerationGroupResponse) +async def get_generation_group(group_id: str, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user)): + logger.info(f"get_generation_group called for group_id: {group_id}") + generations = await generation_service.dao.generations.get_generations_by_group(group_id) + gen_responses = [GenerationResponse(**gen.model_dump()) for gen in generations] + return GenerationGroupResponse(generation_group_id=group_id, generations=gen_responses) + + +@router.get("/{generation_id}", response_model=GenerationResponse) +async def get_generation(generation_id: str, + generation_service: GenerationService = Depends(get_generation_service), + current_user: dict = Depends(get_current_user)) -> GenerationResponse: + logger.debug(f"get_generation called for ID: {generation_id}") + gen = await generation_service.get_generation(generation_id) + if gen and gen.created_by != str(current_user["_id"]): + raise HTTPException(status_code=403, detail="Access denied") + return gen + + @router.post("/import", response_model=GenerationResponse) diff --git a/api/models/GenerationRequest.py b/api/models/GenerationRequest.py index 40e9d18..33a4010 100644 --- a/api/models/GenerationRequest.py +++ b/api/models/GenerationRequest.py @@ -1,7 +1,7 @@ from datetime import datetime, UTC from typing import List, Optional -from pydantic import BaseModel +from pydantic import BaseModel, Field from models.Asset import Asset from models.Generation import GenerationStatus @@ -17,6 +17,7 @@ class GenerationRequest(BaseModel): use_profile_image: bool = True assets_list: List[str] project_id: Optional[str] = None + count: int = Field(default=1, ge=1, le=10) class GenerationsResponse(BaseModel): @@ -45,10 +46,15 @@ class GenerationResponse(BaseModel): progress: int = 0 cost: Optional[float] = None created_by: Optional[str] = None + generation_group_id: Optional[str] = None created_at: datetime = datetime.now(UTC) updated_at: datetime = datetime.now(UTC) +class GenerationGroupResponse(BaseModel): + generation_group_id: str + generations: List[GenerationResponse] + class PromptRequest(BaseModel): prompt: str diff --git a/api/models/__pycache__/GenerationRequest.cpython-313.pyc b/api/models/__pycache__/GenerationRequest.cpython-313.pyc index b958d6046962717bf8d4280f85aad8b55bb9d2cc..0726ddb09aee4939f0106e13d72247b80b8f7d21 100644 GIT binary patch delta 1779 zcma)6O>7%g5Ps`j|E>SB@oy66&ux`#LhMp#LumPW7M5eb7=z<=XV;=P z!-dcbQnVMtg^CkM9FVvHXO7TNsf3kiE}$wm)L0xi@#fjMK}2fT+HYsSnVmOpzVZ8Z z>_Oar=<~^hUH5+exbAcA_+!HFJ9DYGvQBW$kOC=4Iak&VZf@Iia@GSLZaZ?`tPgxy zKlrl&2+WWeNsxkboD^I|$5|0m5&|Qfbu(5z!YVsohu22A%g0>)Bd#E41B@Lx!iG4j zFn06^8?FV_(9U0K&}O+Wmv#D*R;3Ng{$iars^+4wE<5^Fki$Lm1L3fbsn-PoS0j7oz5}eR7tkN3O!|;WZcua`=N_UYA1O+_bLI9r-X%6C- z8Cte_)pB6F+j`#eZIsk{OT#akV1q@-l$DozBOD)j0S_VW-!})n9Sffn} zOfaoks)5#Sl~_)>!jc#)SJ|-ScB?wDdGplSV|#rPm*c(0@q=Xq9O8Z*C@h~T!2a}w zIgeK7S;PwrEc<3dD;sbTiRTcwXe($vkFbhx2?6&GFCd^Yme|zZ0|{TRAfyqF(;6Gc zyBBAM9%71Y6Z3D^J-@niTZ(pLw@!S*#%{^{#XaZXgK(P7S<1X=_jH^Cl|`yKcntP^bMEb<^8{SRykzQjeyB@lyjTxjm7Us z@jWGcJ2JnksCOl`KiU~nk}EarDsbI>je7H4lvCHKp+9CldHkFgbzC#o(Ti}w{866j z_iNpJAYXlF6+O7@U>)Zp_I+ITtN1y>D!X+j1GZUj;KN*hf9T1p5%Vk0VuznSmUjzJ zP(67GcL;9gx0C_(P$a<79EQ?ndrOChlj&nlehT#u1=jOH)A?3;MgKiKEZ4;?tyFC{ z=nA~SqVX@TZ!q{%6a-95f$Yq4r=o2pf4C5I+;6<0+eps`tNG)==(#@7rb@ayLOLyaM1AVbLea}trbEan3_naEuB<%2neC^?m=xey5rp#(}J!=FNi zRParBAsj-ai?ad7%73t7&ITE){J};zt1>pkS-%^iyZJsQ?plm!-876f1v`P+qGg_` zmCOpA$Mf`!9KbK>I({x03W+XyG?g|yE5akB#ptMa5{Iej9h`;woS;6ZXN|eb7qVHM zMBw{^KnrK}Y@T$(ISOZQxoVRb*IcjFDntZL58d$g;spKdz1S530s!1Clb#byc}v(! z_Hv%SkPlHTPU1M762}?|M(UbGlAufh^aAt&@W=xo?FQHb0QPRAVUYI!f;bZ__`Vgi z9f@tk4{Sz8-U=hzk?58XbsfNVf&JSL8w1^tefY`KFQS1Ph1B`UwyrItQK4@FyGPPr zNEJk*2O|B4^#Hd{fVGh%gNuTeA{skFB`g9TJ2?UppTHPMM*+qeIO2^;t!R@8CSN1aFn2l`kfs2R0UQT70RYSHc&oKL zM1bQoz)65k^Mv;B9zlyNkvZB0WH7;tt$g=>KUH;v68vP3xm`@bNT;~t5b$7xC1~i%g=yqt&&v*>iW#WB1(0&XXU$POKl6mzxE6&iwkT@$A%btAiXm-oCriFk$sq&k)@4)w(e`DbsK|o=T3gw{#_OS1|z8w#2Z$rZHqiP9489%q4TpCp+ch)3kBqp;6_s*8qGul;4{$;Nb``6uoD&I@T%Wc z43jOnW?N8P#^tT3sQ-vb8qlPC`V`vO#-6Hj!}Rl2x030ctNNd0*)+<72s;5(bA-y_ zv>b)S(hvxII?6S|s1ahD>@7yVxfwQV?-;7&Q)W|9IT4CxX6I-XPASzEe4}N$i(5dL zL5Q$2H658~@hQyKLf0IF$nDi9bCj1_GzEOyZ*JM0`Il%01bR;tyD z>j}lMFd?uJJs2uJMj-Bk{}0~=XcjzBXH$f2sHvGu>Zf#(YC)LB#qwtwGmSs}#a}2) ze3H85Leem8(l|kw(ihL7YT=e{7&lG{lckI}Yrtr~j(OUhI>V%eRW?>@@}i;b{SsL%wg@Fcz`Ihnhxo*RPo!xW=E3oqR6`_{Nv2gD94^!Xa@9b?$(XD< zLt27_lz1|hkgpHSGMo-O6MQ9f!v=NA>ijyNkvvl zU>KsuAE>#d&7;{2h?E4tRpDOYqRo53)-iAEIAVCq(l~EvyX0+sLVi5`Y z5!0p0+GClAT}Os5Hnm=8T0h^kKEGiyzjN;Orn@dQ+;wF0#j@(7fd>L7hkjGm@qt08 zYW%ZNu(jt01M?Npd`T2UE=vBPU=kl1E*V}vN=}=}@G|3ROY`t@azlV|u{^-v7X@Kg zA;gl9)Wg4_QJ5}75N-(ro2z$hlE?}oiL1#786Zi+l-L3D@d-nwY;R3~5$&1M-7wOh zSV)uG+3AK_ReS2Zv)x(l~8no#U=9uP|kVe?u`H ziocK!_Toz=Ll0<7>e%l7dUmW{-^?>(v`Lt*EvlG=@F@rnq~Od+L1+{fR>M}fYuJ1B z73@@RxtKIDXP~is%9N-mPFzqFC!6@PP3%~&%PR{Xu?}SjOY#<{O}4V*ORC`>JnL$$ zG6pE-j*qA&`JPB36P0uCx?DP=c1E&tPumjhPy!{2>7o_2@#f+2&7&hU#kP7{vU@?U z8ly9qvi#sEl=i(84)xoiaE4Of!bbn&XcMl~}k#Is;*R4%H3Wk)c0Ds4P zy=1FBVL5KeH~L@ceQh9j&2x-o`RP_o>p+K}VKj^sua3nXU|!Kf2zW)PmO{m)!^(c` z^gE1W3;U#BKsJSNKLYp47lFb$M8RSik~I6Pw=w$wO3@#@=5}i(2lpI`4+5x5{|_g8 z9_(;XAyxuS5A<&8j!Bm zfjQL-8Byr6BE>??`C6Z-`h+&>?Umm$&`-m*+2K_;*FK64vnJxn{cvkW_J%3k1$2zP zwW`O+gB!}#uHH#3TCRdFJ;Xk@#%XLsa*&-`Q}4uZs+Ac<4$aZjo;aMUytktxEWgIH znIEfTsJ$8WEQMx*_kH+oiZ!Yl#KMOOWe$}wug+f+t&6PZ$gz` z_NBG`n*Ar)xwTuhp5Z>$u&#qVof}#AdrOv=i9t!o@*F*jbIt*%y4iS=o=0}Ew0MEM zh%#P7JSaZ_7TSUhD8&j^ZIKM6ax#tKh{5%wsHO2F6yYq$XfJ47-x~re(&yNB{XXK) z{mTE2x$hwmeX10jJemwAmeUi!Q#dpQo-_0jq`r>u4TNtZoMP`zdC9k#bzAogFXU%X z>OtT!;>q$VvOh=YLEt<44pOfnd|dFf{DaZOwl18e&$CCj^^hogZCgkC`#5Mn5KhFS z(Di9~7gACg8j%+}P#R{9cQlh9v5j~5vuE&SJlg1Wt$)Qkihd8_C4_|#Xr1;?QO1*k z?+R}n`Z5A0E`1i^6$ETIwNQ;9^)m!aHSH*@L&_n*M;iSFk{*PQIOh``p+{w&t$4Qj z%6YKWvSN(*mGC~3w>J;)on1!tClRh7yoqoW;VS?NCT`)#)PA5@d13n&BInBPyv>sNF1dZ2uUz*RF!M1TNvb@1~6AIBNDsMsKq_Hz#O{Nf~0h z!Y|$d9O!{4B4$d!1Tx#15?Sx!YsF>y`n0q4$GGIP2>)ikjI7+F>hWuy_s;ZhAXZJ` zIGiToS?+YTyufv*56*$3F{jk3a5Ovzr-sr!dQU{2!>@jR{;G8g1F=*JyY2!{E71AQnWnsme!e?F=KEs%tt?(O(^G^aI HPow_=;crK{ delta 3811 zcmZ`+Yj9gf72eh3I+FahWy`iK$(HR%#8K@=P9Kh6XM>kHF(V=2iEECy0R1zv> zr5S*?eF#5zU`G{JkD!Z?^6=XV zzHjaTvNrdRf-zbdhV4I0!h}(oFoR5)QTFh##Y;VW#o`8h+|o=N_)jh4w2?clo0dJY zBWwr93c^<5rZrq1l;j$8V4kunxi(xHG{|+rq3YFUf30(!w+Vj`Zb#?@@S8Q0 zs-$Mw`^Z$>tp`6A?87JM%%n%Ipd*-o7+rpst!3_vT*DzUkAG8Rom7FlLtF{_$@ zz3S@C6(6A!xkDBIqMP^Nj0uEEoMw$NC7M;@k(mQ*FEIQ@CNv_bv6nw#Z_=Hc=I8Cc zsZudzArn_pk@(WWA}hlw7TuK_Ewiw&fG~{^1<;Ie8ma{-+?V)R5QQ-mHt>TEk3o&) z4pyC@p&4NL&3X{0QCyr*vXRx(gq&cw7p9<=QD1)!8M6SY5df^~w&(diM>Czzea&H` zcj#{6q;;mn@~jHDTEXb}$jIQxK;GtcC)RB8Kh_M;d-%<@sqPpql>a0kJ_WyPUxoyk zB3|NwnvzM{2tSjHOs=*FXbQ$QGEPU0#~dzC+B9y!WWg^A7IrtKR*RmHF)}LghMEdK z3qOy&lE2|9;rDuMwQ1w-U<#KL-#7_UnL_bwNSnE3m&Z45ZY1NTd`rsz?;(>F`6*%t zCO+lz*H`9w1tj_?oEjN7&OAbh>&DfeYOq7!`K_z2NSbuyXN=bJb8P9io6j}INu5!NOF1TSaFauliN_w^MGH*-0sIY0! zX7dOlSkj$+gFE|oumwKmZSrB{H7P!WIU*8bHDoIy4K#@@rJlzXz9A{g6rTTp=suVwwH0 zbg|BTc7x-psmW6Qy4!!wz4L5q_pe(A&$SMoZymy6q4SNBs*8b+3#FU*>6U*lTj-I# z;_Ys7vYKwMH=T6!nP70rN&7AIR6Rx7UDPibPi^+?a2rnB=^YN~bXA`L24@^}N1b%0 z#*DO!g3=kcyPulRbQ^$vg;1bhp=P8dBbd{yYvmftr**E&_pR~|xR&0&{p&`JM)+4c z@1)20?>Z}?Kfm4CCh1k1Yu&b&_Uk!;XM~5LRiu*XyP=AtYb%PK6AjO$Gmqo$dW zuVNR@uBotn{FhzeJ^w?OXL%Ev>PHCS1e@OD_1p9qvav!m^WbpbKtF@Z&~F-T16rI- zrWDqKJkg@>L(TGtnp{{)K~Kd#oM4KoGQ4Wn-@m8_tuV{(Cm#aIr&0VFgwG-rAz*^9 zrH^KWyPbcxyFr>p{(k;y_b%xOl0M$r(<+rCS;41zoU3QWS9%t`cpGX~RY}Jqi*uQ5 zh83dP!vOvoy=5Na6}^qLo^S1KqXRtAyG3e4m3IDkZ-;LGQT|%*E-8rICEniWqmSh# z`hIU468UddQ;M?4eu&Z^0cggBWSX7C;abgz%>N$pL~)4tJqttVz~Tx!#19WQ*xP}# z{&q(FC;3ame%hIPYxuG?co;aJu)rOvuyiza6Z;%$d>-Kega;8GLihs07ZHwfx~GwT zi8t*D%!nfVCUV!_LE?@*g}k321Q5jLzlziug!>WR6QApM$lB!@&9Z0sb9>t848Od` zN8jYd6Lq$=OpWm1gqJ?WcTa?te}EQ6=tRzlmjYfR>>CKrBV0dE{ZaD*@QLaR!q$#U?CTkj>;3gT31-q0OuP%k@*>G_4lWF zqOa*e>UM-K0Ka+UiWLV2fRPcu0eeZ;@HCD^qS-7< z&MblPNQBLxavwqv;g&q?0788z*Q>PBjXrbv*`8fj3I0DBI~z}&o17D)tAs4k0mDC0 za&?Sfo9%H5Q)tECu6MW-Xi=o+-}p%4mM->ZU}y!=Bs99{f^c6eS{((l2NuC)>^WL_ zG#*`q%2fmX`(w%?zUf&J|3#vn7V|$P8kdFB;vmXUy->U`)j%J<$o140Szn35vQ^>0 z^*;~txxH?tl3rR+SvBZ02SN?bv@m5cqXy#e*8>JmtH|$aZ$RA=SsjRIY8L3uJ2Q!d z!r1?CeiMR-sgVNJfkSb;Hy*jRVH@mjDW!C?O3;IUM16t(WzJ)NjTBxc)tAVoOXSu| Zr2RJ}c8 List[Generation]: return await self.dao.generations.get_generations(status=GenerationStatus.RUNNING, created_by=user_id, project_id=project_id) - async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None) -> GenerationResponse: + async def create_generation_task(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationGroupResponse: + count = generation_request.count + + if generation_group_id is None: + generation_group_id = str(uuid4()) + + results = [] + for _ in range(count): + gen_response = await self._create_single_generation(generation_request, user_id, generation_group_id) + results.append(gen_response) + return GenerationGroupResponse(generation_group_id=generation_group_id, generations=results) + + async def _create_single_generation(self, generation_request: GenerationRequest, user_id: Optional[str] = None, generation_group_id: Optional[str] = None) -> GenerationResponse: gen_id = None generation_model = None try: - generation_model = Generation(**generation_request.model_dump()) + generation_model = Generation(**generation_request.model_dump(exclude={'count'})) if user_id: generation_model.created_by = user_id + if generation_group_id: + generation_model.generation_group_id = generation_group_id gen_id = await self.dao.generations.create_generation(generation_model) generation_model.id = gen_id diff --git a/models/Generation.py b/models/Generation.py index 6c74100..9a75c84 100644 --- a/models/Generation.py +++ b/models/Generation.py @@ -35,6 +35,7 @@ class Generation(BaseModel): output_token_usage: Optional[int] = None is_deleted: bool = False album_id: Optional[str] = None + generation_group_id: Optional[str] = None created_by: Optional[str] = None # Stores User ID (Telegram ID or Web User ObjectId) project_id: Optional[str] = None created_at: datetime = Field(default_factory=lambda: datetime.now(UTC)) diff --git a/models/__pycache__/Generation.cpython-313.pyc b/models/__pycache__/Generation.cpython-313.pyc index b052293dfd5f8e05d03ae7ee0feab8dd2fb2bb8d..cb9ecdb2d77fb9c455b1c77b3cb8f4ea7aa01bbf 100644 GIT binary patch delta 306 zcmX>i*(k;PnU|M~0SLG+^<~;_vEXSD*B)d3Ii8_JAnL&gL yhyXdi$Y*jMml|XC=BZp#jFb0qUt;u|JdsD8F>>-w9!q{jM&SvPUm1WD*ct!?r!$BE diff --git a/repos/__pycache__/generation_repo.cpython-313.pyc b/repos/__pycache__/generation_repo.cpython-313.pyc index 7ef0c95795a538ea8205ef9fb45e67414a306b5c..34fb78ee00d29c6f06f9e6a6a106d024003dfc50 100644 GIT binary patch delta 739 zcmZ9I&ref95XWcsz1Nl>ZLx0~ORG(*M*27?qP18^kOmAvY$KZ&qcj$Ol%^gEZi~j4 zQ1zr26Wtj94GkBL5dVQz4kb+pUN~@T#H8`&uI7cr@8L7w&Cblu?2qC0i0{bj?FA|8 z`}=al|J4^nnm~Y0kOz(oIAO>&HYid!iDB^^Z(>Az!_{cs&5XP^?=ocV6`h?7!xNnc zScsdY9_!S`HIf+UZ+l97Wo@-ke*B1t6K+akC?dw_u#F5Exj>TgA*BV?<6>w%~G*a@qQF(n?9t{o~BYlY)k zs)dIy{>{@K60iMd?7y?!R)rrW4&fE-($MGJhg^#$nl!N$YTUiwvbgw_1>A&ed_SLswRCig&7}iYnbWRxAm!> zj87~D!v1^mDx0;8tdd2=yTD7r?+7%_HgmH4N_ztc_ph}<(48^LG^(ABkT^Jcl7(X$9sUiuWH~`MqA|(I- diff --git a/repos/generation_repo.py b/repos/generation_repo.py index c668e97..b561548 100644 --- a/repos/generation_repo.py +++ b/repos/generation_repo.py @@ -77,3 +77,11 @@ class GenerationRepo: async def update_generation(self, generation: Generation, ): res = await self.collection.update_one({"_id": ObjectId(generation.id)}, {"$set": generation.model_dump()}) + + async def get_generations_by_group(self, group_id: str) -> List[Generation]: + res = await self.collection.find({"generation_group_id": group_id, "is_deleted": False}).sort("created_at", 1).to_list(None) + generations: List[Generation] = [] + for generation in res: + generation["id"] = str(generation.pop("_id")) + generations.append(Generation(**generation)) + return generations