Priešpriešinis mokymasis su "PyTorch Lightning”
Priešpriešinis mokymasis tapo populiariu ir galingu gilaus mokymosi modelių mokymo metodu, ypač generatyvinio modeliavimo srityje. Šiame kontekste priešpriešinis mokymasis paprastai apima dviejų modelių mokymą vienu metu: generatoriaus ir diskriminatoriaus. Generatoriaus tikslas - sukurti realius pavyzdžius, panašius į tikslinį duomenų pasiskirstymą, o diskriminatoriaus tikslas - teisingai atskirti realius ir sugeneruotus pavyzdžius. Tokia konfigūracija sukuria dinamišką dviejų modelių konkurenciją, skatinančią tobulinti abu modelius.
"PyTorch Lightning" yra galinga ir lanksti sistema, kuri supaprastina gilaus mokymosi modelių mokymo procesą naudojant "PyTorch". Tačiau, kai kalbama apie priešpriešinį mokymąsi, darbo eiga šiek tiek skiriasi nuo įprasto vieno modelio mokymo proceso. Šiame straipsnyje aptarsime, kaip veiksmingai įgyvendinti priešpriešinį mokymąsi "PyTorch Lightning", daugiausia dėmesio skirdami unikaliems dviejų modelių mokymo vienu metu aspektams. Nagrinėsime, kaip apibrėžti generatoriaus ir diskriminatoriaus modelius, taip pat kaip tinkamai nustatyti mokymo ciklą, optimizatorius ir nuostolių funkcijas "PyTorch Lightning" sistemoje. Taip pat aptarsime geriausią praktiką, kaip stebėti ir vertinti abiejų modelių našumą mokymo metu. Šio puslapio pabaigoje jau gerai suprasite, kaip pasinaudoti "PyTorch Lightning" galimybėmis, kad galėtumėte veiksmingai mokyti priešingus modelius ir išnaudoti generatyvinių priešingų tinklų (GAN) galią savo gilaus mokymosi projektuose.
Generatoriaus ir diskriminatoriaus modelių apibrėžimas
Priešingo mokymosi atveju generatoriaus ir diskriminatoriaus modeliai yra du pagrindiniai komponentai. Šiame skyriuje aptarsime, kaip sukurti šių modelių architektūras ir juos inicializuoti.
Generatoriaus modelio kūrimas
Generatoriaus modelis yra atsakingas už naujų duomenų pavyzdžių kūrimą. Paprastai jis priima atsitiktinį triukšmą kaip įvestį ir generuoja duomenų pavyzdžius, panašius į tikslinį pasiskirstymą. Norėdami sukurti generatoriaus architektūrą, kaip bazinę klasę galite naudoti "PyTorch" klasę nn.Module:
class Generator(nn.Module):
def __init__(self):
super(Generator, self).__init__()
# Define the generator layers and architecture here
def forward(self, x):
# Implement the forward pass
return x
Konkreti generatoriaus architektūra priklausys nuo jūsų probleminės srities ir duomenų rinkinio. Vaizdams generuoti galite naudoti konvoliucinius sluoksnius, tekstui generuoti - LSTM arba transformatorių sluoksnius ir t. t.
Diskriminatoriaus modelio kūrimas
Diskriminatoriaus modelis yra atsakingas už tikrų ir sugeneruotų pavyzdžių atskyrimą. Jis priima duomenų imtį kaip įvestį ir išveda tikimybę, kad imtis yra tikra. Panašiai kaip ir generatoriaus modelį, diskriminatoriaus architektūrą galite sukurti naudodami PyTorch nn.Module klasę:
class Discriminator(nn.Module):
def __init__(self):
super(Discriminator, self).__init__()
# Define the discriminator layers and architecture here
def forward(self, x):
# Implement the forward pass
return x
Diskriminatoriaus architektūra taip pat priklauso nuo probleminės srities ir duomenų rinkinio. Atliekant vaizdų klasifikavimo užduotis galima naudoti konvoliucinius sluoksnius, o atliekant teksto klasifikavimo užduotis - LSTM arba transformatorių sluoksnius.
"PyTorch Lightning" modulio įgyvendinimas
Šiame skyriuje aptarsime, kaip įgyvendinti PyTorch Lightning modulį, skirtą priešpriešiniam mokymuisi. Tam reikės apibrėžti forward, training_step, validation_step ir test_step metodus, taip pat nustatyti generatoriaus ir diskriminatoriaus optimizatorius ir tvarkaraščius.
"LightningModule" apžvalga
"LightningModule" yra pagrindinė "PyTorch Lightning" klasė, kuri apima modelį, optimizatorius ir mokymosi logiką. Ji supaprastina mokymo procesą, nes automatizuoja daugelį užduočių, pavyzdžiui, GPU paskirstymą, kontrolinių taškų nustatymą ir registravimą. Priešpriešiniam mokymuisi sukursime pasirinktinį "LightningModule", skirtą generatoriaus ir diskriminatoriaus modelių sąveikai tvarkyti.
Modelių inicijavimas
Apibrėžę generatoriaus ir diskriminatoriaus architektūras, dabar galite inicializuoti modelius savo GAN klasės, iš pl.LightningModule, __init__ metodu:
class GAN(pl.LightningModule):
def __init__(self, hparams: HParams):
super(GAN, self).__init__()
self.generator = Generator()
self.discriminator = Discriminator()
self.hparams = hparams
self.automatic_optimization = False
Svarbu pažymėti, kad mokant GAN su PyTorch Lightning nustatome self.automatic_optimization = False, kad galėtume geriau kontroliuoti generatoriaus ir diskriminatoriaus modelių optimizavimo procesą. Nors automatinis optimizavimas tinka daugumai vieno modelio užduočių, GAN reikalauja pakaitinių generatoriaus ir diskriminatoriaus atnaujinimų kartu su skirtingomis kiekvieno modelio nuostolių funkcijomis. Dėl to reikia labiau kontroliuoti optimizavimo procesą.
Jei parametrui self.automaticoptimization nustatyta reikšmė False, "PyTorch Lightning" leidžia rankiniu būdu atnaujinti generatoriaus ir diskriminatoriaus modelius taikant training_step metodą, todėl lengviau įgyvendinti specifinę mokymo dinamiką, reikalingą priešpirešiniam mokymuisi, pavyzdžiui, atnaujinti modelius skirtingu greičiu arba taikyti skirtingas nuostolių funkcijas.
Nors automatinis optimizavimas gali būti naudojamas keliems "PyTorch Lightning" modeliams, jis ne visada idealiai tinka priešpriešiniams tinklams. Rankinis optimizavimas užtikrina geresnę GAN mokymo proceso kontrolę ir paprastai jį yra paprasčiau įgyvendinti.
Išankstinio metodo apibrėžimas
"LightningModule" metodas "forward" naudojamas generatoriaus perdavimui į priekį apibrėžti. Priešpriešinio mokymosi atveju paprastai kaip įvestis imama triukšmo vektorių partija ir sukuriami sugeneruoti pavyzdžiai:
class GAN(pl.LightningModule):
...
def forward(self, noise):
return self.generator(noise)
Nuostolių funkcijų apibrėžimas
Priešpriešinio mokymosi atveju generatoriaus ir diskriminatoriaus modeliams naudojamos skirtingos nuostolių funkcijos. Dažniausiai naudojamos šios nuostolių funkcijos: dvejetainė kryžminė entropija (BCE), Vaseršteino nuostoliai ir “Hinge” nuostoliai. Šios nuostolių funkcijos skirtingai įvertina generatoriaus ir diskriminatoriaus modelių našumą, o nuostolių funkcijos pasirinkimas gali turėti didelę įtaką mokymo dinamikai ir galutiniam modelio našumui.
Pavyzdžiui, norėdami įgyvendinti dvejetainio kryžminio entropijos nuostolio funkciją, galite apibrėžti atskiras nuostolių funkcijas generatoriui ir diskriminatoriui taip:
def disc_loss(self, real_preds, fake_preds):
real_loss = F.binary_cross_entropy_with_logits(real_preds, torch.ones_like(real_preds))
fake_loss = F.binary_cross_entropy_with_logits(fake_preds, torch.zeros_like(fake_preds))
return real_loss + fake_loss
def gen_loss(self, fake_preds):
return F.binary_cross_entropy_with_logits(fake_preds, torch.ones_like(fake_preds))
Generatoriaus ir diskriminatoriaus optimizatorių ir tvarkaraščių sudarytojų nustatymas
Metodas configure_optimizers naudojamas atskiriems generatoriaus ir diskriminatoriaus optimizatoriams ir tvarkaraščiams nustatyti:
class GAN(pl.LightningModule):
...
def configure_optimizers(self):
optim_g = torch.optim.AdamW(
self.generator.parameters(),
self.hparams.learning_rate,
betas=self.hparams.betas,
eps=self.hparams.eps)
optim_d = torch.optim.AdamW(
self.discriminator.parameters(),
self.hparams.learning_rate,
betas=self.hparams.betas,
eps=self.hparams.eps)
scheduler_g = torch.optim.lr_scheduler.ExponentialLR(optim_g, gamma=self.hps.train.lr_decay)
scheduler_d = torch.optim.lr_scheduler.ExponentialLR(optim_d, gamma=self.hps.train.lr_decay)
return [optim_g, optim_d], [scheduler_g, scheduler_d]
Mokymo strategijos įgyvendinimas
Priešpriešinio mokymosi atveju generatoriaus ir diskriminatoriaus modeliai mokomi iteratyviai, o kiekvieno modelio svoriai atnaujinami remiantis anksčiau apibrėžtomis nuostolių funkcijomis. Training_step metodu yra įgyvendinama priešpriešinio mokymo logika. Juo apskaičiuojami generatoriaus ir diskriminatoriaus nuostoliai ir abiem modeliams atskirai taikomas rankinis optimizavimas:
def training_step(self, batch, batch_idx):
g_opt, d_opt = self.optimizers()
scheduler_g, scheduler_d = self.lr_schedulers()
real_samples, _ = batch
noise = torch.randn(real_samples.size(0), self.noise_dim, device=self.device)
fake_samples = self.generator(noise)
# Update the discriminator
real_preds = self.discriminator(real_samples)
fake_preds = self.discriminator(fake_samples.detach())
disc_loss = self.disc_loss(real_preds, fake_preds)
self.log('disc_loss', disc_loss)
d_opt.zero_grad()
self.manual_backward(disc_loss)
d_opt.step()
scheduler_d.step()
# Update the generator
fake_preds = self.discriminator(fake_samples)
gen_loss = self.gen_loss(fake_preds)
self.log('gen_loss', gen_loss)
g_opt.zero_grad()
self.manual_backward(gen_loss)
g_opt.step()
scheduler_g.step()
Atminkite, kad inicializuodami modelį nustatėme self.automaticoptimization = False. Kai self.automaticoptimization = False, generatoriaus ir diskriminatoriaus modelių optimizavimo procesą reikia atlikti rankiniu būdu. Pateiktame training_step procese lemiamą vaidmenį atlieka metodai zero_grad(), self.manual_backward(loss) ir step():
zero_grad(): Šiuo metodu išvalomi ankstesnio optimizavimo etapo gradientai. Tai būtina, kad gradientai nesusikauptų per kelis optimizavimo etapus, o tai galėtų lemti neteisingą svorio atnaujinimą.
self.manual_backward(loss): Šis metodas apskaičiuoja nuostolių funkcijos gradientus modelio parametrų atžvilgiu. Gradientai reikalingi modelio svoriams atnaujinti.
Be diskriminatoriaus ir generatoriaus nuostolių gradientų apskaičiavimo, manual_backward metodas taip pat pasirūpina mišraus tikslumo mokymu ir gradiento mastelio keitimu, jei jie įjungti PyTorch Lightning konfigūracijoje. Tai leidžia pasinaudoti mišraus tikslumo mokymo teikiamais našumo patobulinimais ir kartu išlaikyti rankinę priešingo mokymosi optimizavimo proceso kontrolę.
3. step(): Šis metodas atnaujina modelio svorius pagal apskaičiuotus gradientus ir optimizatoriaus mokymosi greitį. Tai paskutinis optimizavimo proceso žingsnis.
Validation_step ir test_step metodais galima įvertinti generatoriaus ir diskriminatoriaus veikimą atitinkamai pagal validavimo ir testavimo duomenis.
Generatoriaus ir diskriminatoriaus mokymo pusiausvyra
Vienas iš pagrindinių priešpriešinio mokymosi uždavinių - suderinti generatoriaus ir diskriminatoriaus modelių mokymą. Jei vienas modelis, palyginti su kitu, tampa pernelyg galingas, mokymo procesas gali tapti nestabilus arba gali nepavykti konverguoti. Norėdami sušvelninti šią problemą, apsvarstykite galimybę įgyvendinti toliau nurodytas strategijas:
Atnaujinimo dažniai: Sureguliuokite kiekvieno modelio atnaujinimų skaičių per epochą. Pavyzdžiui, diskriminatorių atnaujinkite kelis kartus kiekvienam generatoriaus atnaujinimui arba atvirkščiai. Tai gali padėti išvengti, kad vienas modelis netaptų pernelyg dominuojantis.
for _ in range(self.hparams.disc_update_freq):
d_opt.zero_grad()
self.manual_backward(disc_loss)
d_opt.step()
for _ in range(self.hparams.gen_update_freq):
g_opt.zero_grad()
self.manual_backward(gen_loss)
g_opt.step()
Mokymosi greičio planavimas: Naudokite skirtingus mokymosi greičio tvarkaraščius generatoriaus ir diskriminatoriaus modeliams. Įprasta naudoti mažesnį diskriminatoriaus mokymosi greitį, kad jis neužgožtų generatoriaus.
Gradiento apkarpymas arba “gradient penalty”: Taikykite gradiento apkarpymo arba “gradient penalty” metodus, kad stabilizuotumėte mokymo dinamiką. Gradiento apkarpymas apriboja didžiausią gradiento vertę, taip užkirsdamas kelią itin dideliems atnaujinimams, o “gradient penalty” į nuostolių funkciją įtraukia narį, skatinantį sklandesnius gradientus ir mažinantį nestabilių atnaujinimų tikimybę.
Spektrinis normalizavimas: Šis metodas normalizuoja diskriminatoriaus sluoksnių svorius, užtikrindamas Lipschitz tęstinumą ir padėdamas stabilizuoti mokymo procesą.
Naudokite skirtingas nuostolių funkcijas: Eksperimentuokite su skirtingais nuostolių funkcijų deriniais generatoriaus ir diskriminatoriaus modeliams. Pavyzdžiui, diskriminatoriui galite naudoti Wassersteino nuostolius su “gradient penalty”, o generatoriui - “Hinge” nuostolius.
Stebėsena ir ankstyvas sustabdymas: Stebėkite nuostolių vertes, modelio rezultatus ir kitus našumo rodiklius viso mokymo proceso metu. Jei generatoriaus arba diskriminatoriaus nuostolių vertės tampa per mažos arba per didelės, apsvarstykite galimybę koreguoti mokymosi greitį, atnaujinimo dažnius arba kitus hiperparametrus. Įgyvendinkite ankstyvo sustabdymo kriterijus, pagrįstus modelio našumu arba mokymo stabilumu.
Kruopščiai subalansavę generatoriaus ir diskriminatoriaus modelių mokymą, galite užtikrinti priešpriešinio mokymosi proceso stabilumą ir pasiekti geresnių rezultatų generuojant įvairius ir tikroviškus rezultatus.
Stebėsena ir vertinimas
GAN pažangos stebėjimas ir veikimo vertinimas yra esminiai priešpriešinio mokymosi proceso etapai. Šiame skyriuje aptarsime, kaip stebėti mokymo eigą, vizualizuoti generatoriaus išvestis ir diskriminatoriaus veikimą, įvertinti modelio konvergenciją ir stabilumą.
Mokymo pažangos stebėjimas naudojant "TensorBoard”
"PyTorch Lightning" turi integruotą "TensorBoard" palaikymą, leidžiantį stebėti įvairias metrikas, pavyzdžiui, nuostolių vertes, gradientus ir svorius viso mokymo proceso metu. Nors "PyTorch Lightning" automatiškai sukuria numatytąjį "TensorBoard" žurnalą, kai inicijuojate “Trainer” nenurodydami žurnalo, gera praktika yra aiškiai sukurti ir sukonfigūruoti žurnalą pagal savo poreikius.
Pateikiame pavyzdį, kaip sukurti "TensorBoardLogger" ir perduoti jį “Trainer”:
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import TensorBoardLogger
from src import get_hparams()
hparams = get_hparams()
model = GAN(hparams)
logger = TensorBoardLogger(save_dir='logs', name='my_experiment')
trainer = Trainer(logger=logger, max_epochs=100, gpus=1)
trainer.fit(model, dataloader)
Norėdami registruoti generatoriaus ir diskriminatoriaus nuostolius, naudokite funkcijos training_step metodą self.log():
self.log('disc_loss', disc_loss)
self.log('gen_loss', gen_loss)
Šie žurnalai bus automatiškai vizualizuojami "TensorBoard", todėl galėsite stebėti mokymo eigą ir nustatyti galimas problemas.
Taip pat galite registruoti vaizdus, garsą ir kitus duomenis, kuriuos palaiko "TensorBoard", pvz:
self.logger.experiment.add_image("gen/image", image_tensor, self.global_step, dataformats='HWC')
self.logger.experiment.add_audio("gen/audio", audio_tensor, self.global_step, self.hps.data.sample_rate)
Daugiau informacijos apie prisijungimą "Pytorch Lightning" rasite oficialioje dokumentacijoje.
Generuojamų imčių vizualizavimas yra labai svarbus GAN efektyvumo vertinimo aspektas, todėl rekomenduojama stebėti generuojamų imčių kokybę ir įvairovę laikui bėgant.
Modelio konvergencijos ir stabilumo vertinimas
Gali būti sudėtinga įvertinti GAN konvergavimą ir stabilumą, nes nėra galutinės generatyvinių modelių našumo vertinimo metrikos. Tačiau galite naudoti keletą metodų, padedančių įvertinti bendrą modelio konvergavimą ir stabilumą:
Stebėkite nuostolių vertes: Stebėkite generatoriaus ir diskriminatoriaus nuostolių vertes viso mokymo proceso metu. Stabilios GAN abiejų modelių nuostolių reikšmės turėtų būti gana pastovios.
Vizuali apžiūra: Reguliariai tikrinkite gautus pavyzdžius, kad užtikrintumėte, jog laikui bėgant jų kokybė ir įvairovė gerėja. Tai gali padėti nustatyti režimo žlugimą ar kitas su mokymo procesu susijusias problemas.
Kiekybinis vertinimas: Naudokite kiekybines metrikas, tokias kaip “Inception Score” (IS), “Frechet Inception Distance” (FID) ar kitas konkrečiai sričiai būdingas metrikas, kad įvertintumėte sukurtų imčių kokybę ir įvairovę.
Atidžiai stebėdami ir vertindami savo GAN, galite gauti informacijos apie jo veikimą, nustatyti galimas problemas ir atlikti reikiamus koregavimus, kad pasiektumėte geresnių rezultatų.
Straipsnyje apžvelgiamas priešpriešinis mokymasis naudojant "PyTorch Lightning" - galingą ir lanksčią gilaus mokymosi sistemą, kuri supaprastina gilaus mokymosi modelių mokymo procesą. Išnagrinėjome unikalius generatoriaus ir diskriminatoriaus modelių mokymo vienu metu aspektus, įskaitant modelių apibrėžimą, mokymo ciklo, optimizatorių ir nuostolių funkcijų nustatymą, taip pat abiejų modelių veikimo stebėjimą ir vertinimą mokymo metu.
Ieškote pagalbos diegiant naują sprendimą savo verslo procesuose? Susisiekite su mūsų komanda!