FLUX (Black Forest Labs) per pastaruosius kelis mėnesius audringai užvaldė dirbtinio intelekto vaizdų generavimo pasaulį. Jis ne tik įveikė Stable Diffusion (ankstesnį atvirojo kodo karalių) pagal daugelį etalonų, bet ir kai kuriais rodikliais pranoko patentuotus modelius, tokius kaip Dall-E ar Midjourney.
Bet kaip jūs ketinate naudoti FLUX vienoje iš savo programų? Galima manyti, kad reikia naudoti be serverio pagrindinius kompiuterius, pvz., „Replicate“ ir kitus, tačiau jie gali labai greitai pabrangti ir nesuteikti reikiamo lankstumo. Čia praverčia susikurti savo FLUX serverį.
Šiame straipsnyje paaiškinsime, kaip sukurti savo FLUX serverį naudojant Python. Šis serveris leis generuoti vaizdus pagal tekstinius raginimus naudojant paprastą API. Nesvarbu, ar naudojate šį serverį asmeniniam naudojimui, ar diegiate jį kaip gamybinės programos dalį, šis vadovas padės jums pradėti.
Būtinos sąlygos
Prieš pasinerdami į kodą, įsitikinkime, kad turite būtinus įrankius ir bibliotekas:
- Python: jūsų kompiuteryje reikės įdiegti Python 3, pageidautina 3.10 versiją.
torch
: giluminio mokymosi sistema, kurią naudosime vykdydami FLUX.diffusers
: suteikia prieigą prie FLUX modelio.transformers
: Reikalinga difuzorių priklausomybė.sentencepiece
: reikalingas norint paleisti FLUX prieigos raktąprotobuf
: reikalingas norint paleisti FLUXaccelerate
: kai kuriais atvejais padeda efektyviau įkelti FLUX modelį.fastapi
: sistema, skirta sukurti žiniatinklio serverį, galintį priimti vaizdų generavimo užklausas.uvicorn
: reikalingas norint paleisti „fastapi“ serverį.psutil
: leidžia patikrinti, kiek RAM yra mūsų kompiuteryje.
Galite įdiegti visas bibliotekas vykdydami šią komandą: pip install torch diffusers transformers accelerate fastapi uvicorn psutil
.
Pastaba „MacOS“ naudotojams: jei naudojate „Mac“ kompiuterį su M1 arba M2 lustu, turėtumėte nustatyti „PyTorch“ su „Metal“, kad užtikrintumėte optimalų veikimą. Prieš tęsdami vadovaukitės oficialiu „PyTorch with Metal“ vadovu.
1 veiksmas: aplinkos nustatymas
Pradėkime scenarijų pasirinkdami tinkamą įrenginį, kad būtų galima daryti išvadas pagal mūsų naudojamą aparatinę įrangą.
import torch
device = 'cuda' # can also be 'cpu' or 'mps'
if device == 'mps' and not torch.backends.mps.is_available():
raise Exception("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
raise Exception("Device set to CUDA, but CUDA is not available")
Galite nurodyti cpu
, cuda
(skirta NVIDIA GPU) arba mps
(skirta Apple Metal Performance Shaders). Tada scenarijus patikrina, ar pasirinktas įrenginys yra prieinamas, ir pateikia išimtį, jei jos nėra.
2 veiksmas: FLUX modelio įkėlimas
Toliau įkeliame FLUX modelį. Įkelsime modelį fp16 tikslumu, o tai sutaupys šiek tiek atminties neprarandant kokybės.
Pastaba: šiuo metu jūsų gali būti paprašyta autentifikuoti HuggingFace, nes FLUX modelis yra uždarytas. Kad autentifikavimas būtų sėkmingas, turėsite sukurti HuggingFace paskyrą, eiti į modelio puslapį, sutikti su sąlygomis, tada paskyros nustatymuose sukurti HuggingFace prieigos raktą ir pridėti jį savo įrenginyje kaip HF_TOKEN aplinkos kintamąjį.
from diffusers import DDIMScheduler, FluxPipeline
import psutil
model_name = "black-forest-labs/FLUX.1-dev"
print(f"Loading {model_name} on {device}")
pipeline = FluxPipeline.from_pretrained(
model_name,
torch_dtype=torch.float16,
use_safetensors=True
).to(device)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
Čia mes įkeliame FLUX modelį naudodami difuzorių biblioteką. Mūsų naudojamas modelis yra black-forest-labs/FLUX.1-dev
įkeltas fp16 tikslumu. Taip pat yra FLUX pro modelis, kuris yra stipresnis, bet, deja, nėra atvirojo kodo, todėl jo negalima naudoti.
Čia naudosime DDIM planuoklį, bet taip pat galite pasirinkti kitą, pvz., Euler arba UniPC. Daugiau apie tvarkaraščius galite perskaityti čia.
Kadangi vaizdų generavimas gali pareikalauti daug išteklių, labai svarbu optimizuoti atminties naudojimą, ypač kai naudojamas CPU arba įrenginys su ribota atmintimi.
# Recommended if running on MPS or CPU with < 64 GB of RAM
total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
print("Enabling attention slicing")
pipeline.enable_attention_slicing()
Šis kodas tikrina visą laisvą atmintį ir įgalina dėmesį, jei sistemoje yra mažiau nei 64 GB RAM. Dėmesys sumažina atminties naudojimą generuojant vaizdą, o tai būtina įrenginiams su ribotais ištekliais.
3 veiksmas: API sukūrimas naudojant FastAPI
Tada nustatysime FastAPI serverį, kuris suteiks API vaizdams generuoti.
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)
FastAPI yra populiari žiniatinklio API kūrimo sistema naudojant Python. Šiuo atveju mes naudojame jį norėdami sukurti serverį, kuris gali priimti užklausas dėl vaizdo generavimo. Atsakymui suspausti taip pat naudojame GZip tarpinę programinę įrangą, kuri ypač naudinga siunčiant vaizdus atgal base64 formatu.
Pastaba: gamybinėje aplinkoje galbūt norėsite saugoti sugeneruotus vaizdus S3 segmente arba kitoje debesies saugykloje ir pateikti URL adresus, o ne base64 koduotas eilutes, kad galėtumėte pasinaudoti CDN ir kitais optimizavimais.
4 veiksmas: užklausos modelio apibrėžimas
Turime apibrėžti užklausų, kurias priims mūsų API, modelį.
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: str
seed: conint(ge=0) = Field(..., description="Seed for random number generation")
height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive integer or 0")
steps: conint(ge=0) = Field(..., description="Number of steps")
batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")
Šis GenerateRequest modelis apibrėžia parametrus, reikalingus vaizdui generuoti. Raginimas yra tekstinis vaizdo, kurį norite sukurti, aprašymas. Neigiamas_prompt gali būti naudojamas norint nurodyti, ko vaizde nenorite. Kiti laukai apima vaizdo matmenis, išvados žingsnių skaičių ir partijos dydį.
5 veiksmas: vaizdo generavimo pabaigos taško sukūrimas
Dabar sukurkime galinį tašką, kuris tvarkys vaizdo generavimo užklausas.
@app.post("https://www.sitepoint.com/")
async def generate_image(request: GenerateRequest):
if request.height % 8 != 0 or request.width % 8 != 0:
raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")
generator = (torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size))
images = pipeline(
height=request.height,
width=request.width,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
generator=generator,
num_inference_steps=request.steps,
guidance_scale=request.cfg,
num_images_per_prompt=request.batch_size
).images
base64_images = ()
for image in images:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
base64_images.append(img_str)
return {
"images": base64_images,
}
Šis galutinis taškas tvarko vaizdo generavimo procesą. Pirmiausia patvirtinama, kad aukštis ir plotis yra 8 kartotiniai, kaip reikalauja FLUX. Tada jis generuoja vaizdus pagal pateiktą raginimą ir grąžina juos kaip base64 koduotas eilutes.
6 veiksmas: paleiskite serverį
Galiausiai pridėkime kodą serveriui paleisti, kai vykdomas scenarijus.
@app.on_event("startup")
async def startup_event():
print("Image generation server running")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)
Šis kodas paleidžia FastAPI serverį prie 8000 prievado, todėl jį galima pasiekti iš http://localhost:8000
.
7 veiksmas: patikrinkite savo serverį vietoje
Dabar, kai jūsų FLUX serveris yra paruoštas ir veikia, laikas jį išbandyti. Norėdami bendrauti su serveriu, galite naudoti „curl“ – komandinės eilutės įrankį HTTP užklausoms teikti:
curl -X POST "http://localhost:8000/" \
-H "Content-Type: application/json" \
-d '{
"prompt": "A futuristic cityscape at sunset",
"negative_prompt": "low quality, blurry",
"seed": 42,
"height": 512,
"width": 512,
"cfg": 7.5,
"steps": 50,
"batch_size": 1
}'
Išvada
Sveikiname! Naudodami Python sėkmingai sukūrėte savo FLUX serverį. Ši sąranka leidžia generuoti vaizdus pagal tekstinius raginimus naudojant paprastą API. Jei nesate patenkinti bazinio FLUX modelio rezultatais, galite patobulinti modelį, kad būtų dar geresnis našumas arba konkretūs naudojimo atvejai.
Pilnas kodas
Visą šiame vadove naudojamą kodą galite rasti žemiau:
import torch
device = 'cuda'
if device == 'mps' and not torch.backends.mps.is_available():
raise Exception("Device set to MPS, but MPS is not available")
elif device == 'cuda' and not torch.cuda.is_available():
raise Exception("Device set to CUDA, but CUDA is not available")
from diffusers import DDIMScheduler, FluxPipeline
import psutil
model_name = "black-forest-labs/FLUX.1-dev"
print(f"Loading {model_name} on {device}")
pipeline = FluxPipeline.from_pretrained(
model_name,
torch_dtype=torch.float16,
use_safetensors=True
).to(device)
pipeline.scheduler = DDIMScheduler.from_config(pipeline.scheduler.config)
total_memory = psutil.virtual_memory().total
total_memory_gb = total_memory / (1024 ** 3)
if (device == 'cpu' or device == 'mps') and total_memory_gb < 64:
print("Enabling attention slicing")
pipeline.enable_attention_slicing()
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel, Field, conint, confloat
from fastapi.middleware.gzip import GZipMiddleware
from io import BytesIO
import base64
app = FastAPI()
app.add_middleware(GZipMiddleware, minimum_size=1000, compresslevel=7)
class GenerateRequest(BaseModel):
prompt: str
negative_prompt: str
seed: conint(ge=0) = Field(..., description="Seed for random number generation")
height: conint(gt=0) = Field(..., description="Height of the generated image, must be a positive integer and a multiple of 8")
width: conint(gt=0) = Field(..., description="Width of the generated image, must be a positive integer and a multiple of 8")
cfg: confloat(gt=0) = Field(..., description="CFG (classifier-free guidance scale), must be a positive integer or 0")
steps: conint(ge=0) = Field(..., description="Number of steps")
batch_size: conint(gt=0) = Field(..., description="Number of images to generate in a batch")
@app.post("https://www.sitepoint.com/")
async def generate_image(request: GenerateRequest):
if request.height % 8 != 0 or request.width % 8 != 0:
raise HTTPException(status_code=400, detail="Height and width must both be multiples of 8")
generator = (torch.Generator(device="cpu").manual_seed(i) for i in range(request.seed, request.seed + request.batch_size))
images = pipeline(
height=request.height,
width=request.width,
prompt=request.prompt,
negative_prompt=request.negative_prompt,
generator=generator,
num_inference_steps=request.steps,
guidance_scale=request.cfg,
num_images_per_prompt=request.batch_size
).images
base64_images = ()
for image in images:
buffered = BytesIO()
image.save(buffered, format="PNG")
img_str = base64.b64encode(buffered.getvalue()).decode("utf-8")
base64_images.append(img_str)
return {
"images": base64_images,
}
@app.on_event("startup")
async def startup_event():
print("Image generation server running")
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)