Skip to content
Open
Show file tree
Hide file tree
Changes from 3 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion scripts/sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from src.data.commonsense_dataset import CommonSenseDataset
from src.data.utils import Preprocessor
from src.diffusion.model import DiDi, get_components
from src.sampling import sample
from src.pipeline.sampling import sample


def configure_arg_parser():
Expand Down
2 changes: 1 addition & 1 deletion scripts/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@
from src.data.reddit_dataset import RedditDataset
from src.diffusion.model import DiDi
from src.diffusion.model import get_components
from src.training import train_model
from src.pipeline.training import train_model
from src.utils import filter_warnings, setup_logger, zero_rank_info


Expand Down
73 changes: 73 additions & 0 deletions scripts/train_adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,73 @@
import argparse
from os import environ
from os.path import join

import torch
from omegaconf import OmegaConf
from torch.utils.data import DataLoader

from src.conditioning.adapter import Adapter
from src.data.convai2_dataset import ConvAI2Dataset
from src.diffusion.model import DiDi
from src.pipeline.training import train_model
from src.utils import filter_warnings, setup_logger, zero_rank_info


def configure_arg_parser():
    """Build the command-line parser for the adapter training script.

    Returns:
        An ``argparse.ArgumentParser`` with three required positional
        arguments (``config_path``, ``dataset_dir``, ``model_path``) and one
        optional ``--condition`` flag that defaults to ``"other"``.
    """
    arg_parser = argparse.ArgumentParser()

    # Required positionals, registered in the order they appear on the command line.
    positional_arguments = (
        ("config_path", "Path to YAML config file"),
        ("dataset_dir", "Path to dataset directory"),
        ("model_path", "Path to DiDi model"),
    )
    for arg_name, arg_help in positional_arguments:
        arg_parser.add_argument(arg_name, type=str, help=arg_help)

    arg_parser.add_argument("--condition", type=str, default="other", help="Type of persona")
    return arg_parser


def main(config_path: str, dataset_dir: str, model_path: str, condition: str):
    """Train an adapter on top of a DiDi checkpoint using ConvAI2 data.

    Args:
        config_path: Path to the YAML config loaded with OmegaConf.
        dataset_dir: Directory holding the ``train_*`` and ``valid_*`` ConvAI2 files.
        model_path: Path to the DiDi checkpoint that the adapter wraps.
        condition: Persona tag substituted into the dataset file names.
    """
    filter_warnings()
    setup_logger()
    # Set the TOKENIZERS_PARALLELISM flag before any tokenizer is created.
    environ["TOKENIZERS_PARALLELISM"] = "false"
    torch.set_float32_matmul_precision("high")

    config = OmegaConf.load(config_path)
    config_yaml = OmegaConf.to_yaml(config, resolve=False, sort_keys=True)
    zero_rank_info(f"Loaded config:\n{config_yaml}")

    # The persona tag selects which pair of ConvAI2 files is read.
    train_file = join(dataset_dir, f"train_{condition}_revised_no_cands.txt")
    val_file = join(dataset_dir, f"valid_{condition}_revised_no_cands.txt")
    train_dataset = ConvAI2Dataset(train_file, config.base_name, **config.dataset)
    val_dataset = ConvAI2Dataset(val_file, config.base_name, **config.dataset)

    # Options shared by the training and validation loaders.
    shared_loader_options = {"pin_memory": True, "num_workers": 1}
    train_dataloader = DataLoader(
        train_dataset,
        batch_size=config.batch_size,
        collate_fn=train_dataset.collate_fn,
        **shared_loader_options,
    )
    val_dataloader = DataLoader(
        val_dataset,
        batch_size=config.val_batch_size,
        collate_fn=val_dataset.collate_fn,
        **shared_loader_options,
    )

    # Wrap the loaded DiDi checkpoint in the adapter that is actually trained.
    adapter = Adapter(DiDi.load_from_checkpoint(model_path))

    train_model(
        adapter,
        train_dataloader,
        val_dataloader,
        config.trainer,
        seed=config.seed,
        save_interval=config.save_interval,
    )


if __name__ == "__main__":
    # Parse the command line and forward every parsed option to `main` by keyword.
    _args = configure_arg_parser().parse_args()
    _cli_options = vars(_args)
    main(**_cli_options)
Empty file added src/conditioning/__init__.py
Empty file.
101 changes: 101 additions & 0 deletions src/conditioning/adapter.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,101 @@
import torch
from lightning import LightningModule
from torch import nn

from src.diffusion.model import DiDi
from src.diffusion.utils import get_diffusion_variables, get_x0
from src.pipeline.utils import calculate_train_step, freeze_params, get_cached_content, get_optimizers


class AdapterBlock(nn.Module):
    """Cross-attention block in which decoder states attend to a condition sequence.

    Queries are projected from the decoder hidden states, while keys and values
    are projected from the condition (``encoder_hidden_states``). All tensors
    are expected batch-first, i.e. ``[batch size; seq len; emb dim]``.
    """

    def __init__(self, input_dim: int, num_heads: int):
        """Create the attention module and its input projections.

        Args:
            input_dim: Embedding size of both the decoder and condition states.
            num_heads: Number of attention heads; must divide ``input_dim``.
        """
        super().__init__()
        # batch_first=True: the surrounding model works with [batch; seq; dim]
        # tensors, whereas nn.MultiheadAttention defaults to [seq; batch; dim].
        self.attention = nn.MultiheadAttention(embed_dim=input_dim, num_heads=num_heads, batch_first=True)
        self.query = nn.Linear(input_dim, input_dim)
        self.key = nn.Linear(input_dim, input_dim)
        self.value = nn.Linear(input_dim, input_dim)

    def forward(self, hidden_states, encoder_hidden_states, encoder_attention_mask=None):
        """Attend from ``hidden_states`` to ``encoder_hidden_states``.

        Args:
            hidden_states: Decoder states, ``[batch size; target len; emb dim]``.
            encoder_hidden_states: Condition states, ``[batch size; source len; emb dim]``.
            encoder_attention_mask: Optional ``[batch size; source len]`` mask in
                the tokenizer convention (non-zero = real token, zero = padding).
                ``None`` disables masking.

        Returns:
            The tuple returned by ``nn.MultiheadAttention``: the attended states
            of shape ``[batch size; target len; emb dim]`` followed by ``None``
            (attention weights are not requested).
        """
        query = self.query(hidden_states)
        key = self.key(encoder_hidden_states)
        value = self.value(encoder_hidden_states)

        # PyTorch ignores the positions where key_padding_mask is True, which is
        # the opposite of the tokenizer convention, so padding (zeros) maps to True.
        key_padding_mask = None
        if encoder_attention_mask is not None:
            key_padding_mask = encoder_attention_mask == 0

        # Argument order is (query, key, value); the output follows the query length.
        return self.attention(query, key, value, need_weights=False, key_padding_mask=key_padding_mask)


class Adapter(LightningModule):
    """Lightning module that interleaves adapter blocks with the layers of a DiDi decoder.

    Every decoder layer of the wrapped DiDi model is followed by an
    ``AdapterBlock`` that attends to an additional condition sequence.
    ``freeze_params`` is applied to the DiDi model (expected to disable its
    gradients), so the adapter blocks are the only modules created here for training.
    """

    def __init__(self, didi: DiDi, lr: float = 0.001, warmup_steps: int = 1, min_lr: float = None):
        """Wrap ``didi`` and create one adapter block per decoder layer.

        Args:
            didi: DiDi model whose decoder layers are reused.
            lr: Stored as ``self.lr``.
            warmup_steps: Stored as ``self.warmup``.
            min_lr: Stored as ``self.min_lr``.

        The three stored values are presumably read by ``get_optimizers`` —
        confirm against that helper.
        """
        super().__init__()
        self.didi = didi
        freeze_params(self.didi)

        # Decoder layers are kept in a plain list (not an nn.ModuleList);
        # they stay reachable as submodules through `self.didi`.
        self.decoder_layers = []
        adapter_layers = []
        for layer in didi.decoder.encoder.layer:
            self.decoder_layers.append(layer)
            # The adapter width matches the output width of the decoder layer; one head.
            adapter_layers.append(AdapterBlock(layer.output.dense.out_features, 1))

        self.adapter_layers = nn.ModuleList(adapter_layers)

        self.lr, self.warmup, self.min_lr = lr, warmup_steps, min_lr

    def configure_optimizers(self):
        """Return whatever ``get_optimizers`` builds for this module."""
        return get_optimizers(self)

    def forward(
        self,
        encoder_input_ids: torch.Tensor = None,
        encoder_attention_mask: torch.Tensor = None,
        decoder_inputs_embeds: torch.Tensor = None,
        condition_input_ids: torch.Tensor = None,
        condition_attention_mask: torch.Tensor = None,
        time_ids: torch.Tensor = None,
        context: torch.Tensor = None,
        condition: torch.Tensor = None,
    ):
        """Run the decoder layers, each followed by its adapter block.

        Args:
            encoder_input_ids: Token ids of the dialogue context; used when ``context`` is not given.
            encoder_attention_mask: Attention mask belonging to the context.
            decoder_inputs_embeds: Noisy target embeddings fed to the decoder.
            condition_input_ids: Token ids of the condition; used when ``condition`` is not given.
            condition_attention_mask: Attention mask belonging to the condition.
            time_ids: Diffusion step indices passed to ``didi.time_embeds``.
            context: Precomputed context representation, reused instead of re-encoding.
            condition: Precomputed condition representation, reused instead of re-encoding.

        Returns:
            A tuple ``(hidden_states, context, condition)`` so that callers can
            cache ``context`` and ``condition`` and pass them back in.

        Raises:
            ValueError: If neither ids nor a precomputed tensor is supplied for
                the context or for the condition.
        """
        if encoder_input_ids is None and context is None:
            raise ValueError("Either `encoder_input_ids` or `context` must be provided.")

        if condition_input_ids is None and condition is None:
            raise ValueError("Either `condition_input_ids` or `condition` must be provided.")

        # Use explicit `is None` checks: `tensor or ...` evaluates the tensor's
        # truth value, which raises for any tensor with more than one element.
        if context is None:
            context = get_cached_content(self.didi, encoder_input_ids, encoder_attention_mask)
        if condition is None:
            condition = get_cached_content(self.didi, condition_input_ids, condition_attention_mask)

        time_embeds = self.didi.time_embeds(time_ids)
        hidden_states = decoder_inputs_embeds + time_embeds

        for decoder_layer, adapter_layer in zip(self.decoder_layers, self.adapter_layers):
            # NOTE(review): the mask is forwarded to the decoder layer as received;
            # confirm that this is the format the layer expects.
            output = decoder_layer(
                hidden_states=hidden_states,
                encoder_hidden_states=context,
                encoder_attention_mask=encoder_attention_mask,
            )[0]
            # Both calls return tuples; the first element holds the hidden states.
            hidden_states = adapter_layer(
                hidden_states=output,
                encoder_hidden_states=condition,
                encoder_attention_mask=condition_attention_mask,
            )[0]

        return hidden_states, context, condition

    def training_step(self, batch: list, batch_idx: int):
        """Compute the training loss for one ``(context, target, condition)`` batch.

        Args:
            batch: Triple of tokenized context, target and condition; each item
                exposes ``input_ids`` and ``attention_mask``.
            batch_idx: Index of the batch (unused).

        Returns:
            The loss produced by ``calculate_train_step``.
        """
        raw_context, target, condition = batch
        emb = self.didi.emb(target.input_ids)
        x_0 = get_x0(emb, self.didi.std_0)
        noise = torch.randn_like(x_0)

        # x: [batch size; seq len; emb dim], t: [batch size]
        x_t, t = get_diffusion_variables(self.didi.diffusion_steps, x_0, self.didi.sigmas, noise)

        x_0_hat, *_ = self(
            encoder_input_ids=raw_context.input_ids,
            encoder_attention_mask=raw_context.attention_mask,
            decoder_inputs_embeds=x_t,
            time_ids=t,
            condition_input_ids=condition.input_ids,
            condition_attention_mask=condition.attention_mask,
        )  # [batch size; seq len; emb dim]

        loss, metrics = calculate_train_step(self.didi, emb, x_0, x_0_hat, target, t)
        self.log_dict(metrics, sync_dist=True, on_step=True, on_epoch=False)
        return loss
72 changes: 50 additions & 22 deletions src/data/convai2_dataset.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional

from loguru import logger
from torch.utils.data import Dataset
from tqdm.auto import tqdm
from transformers import AutoTokenizer
from src.data.utils import Preprocessor


@dataclass
Expand All @@ -15,25 +17,48 @@ class ConvAI2Dialog:
partner_persona: Optional[list[str]] = None


class Conditions(Enum):
    """Which persona, if any, accompanies the samples of a ConvAI2 dataset."""

    # No persona is used as a condition.
    NONE = 0

    # The speaker's own persona is used as the condition.
    YOUR = 1

    # The partner's persona is used as the condition.
    PARTNERS = 2


def get_condition(path: str):
    """Infer the persona condition from the name of a dataset file.

    Only the final path component is inspected, so directory names that happen
    to contain "none" or "self" cannot change the result.

    Args:
        path: Path to a ConvAI2 dataset file, e.g. ``train_self_revised_no_cands.txt``.

    Returns:
        ``Conditions.NONE`` if the file name contains "none", ``Conditions.YOUR``
        if it contains "self", and ``Conditions.PARTNERS`` otherwise.
    """
    from os.path import basename

    file_name = basename(path)
    if "none" in file_name:
        return Conditions.NONE
    if "self" in file_name:
        return Conditions.YOUR
    return Conditions.PARTNERS
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This logic is too complicated. Since we only collect personas in the collate function, let's pass the required arguments directly to it.

def collate_fn(..., return_my_persona, return_partner_persona)

Also, let's return a dict, e.g.:

{
    "context": self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs)
    "candidates":  self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs)
}



class ConvAI2Dataset(Dataset):
_YOUR_PERSONA_PREFIX = "your persona: "
_PARTNER_PERSONA_PREFIX = "partner's persona: "

def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, have_candidates=True):
def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, max_condition_len=None):
self.dataset = []
self.num_dialogs = 0

self.context_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, truncation_side="left")
self.candidate_tokenizer = AutoTokenizer.from_pretrained(tokenizer_name)

bos = self.context_tokenizer.bos_token
eos = self.context_tokenizer.eos_token
self.tokenizer_kwargs = {
"padding": "max_length",
"truncation": True,
"return_tensors": "pt",
"add_special_tokens": False,
}
preprocessor = Preprocessor(tokenizer_name)
bos = preprocessor.bos
eos = preprocessor.eos

self.max_context_len = max_context_len
self.max_target_len = max_target_len or max_context_len
self.max_condition_len = max_condition_len or max_context_len

self.have_candidates = have_candidates
self.have_candidates = not "no_cands" in path
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why should this be in the path? We use the same dataset both with and without candidates.

self.vocab_size = self.context_tokenizer.vocab_size
self.condition = get_condition(path)

logger.info(f"Loading dataset from '{path}'")
with open(path, "r") as f:
Expand All @@ -53,7 +78,7 @@ def __init__(self, path, tokenizer_name, max_context_len, max_target_len=None, h
partner_persona.append(line[len(self._PARTNER_PERSONA_PREFIX) :])
continue

if have_candidates:
if self.have_candidates:
utterance1, utterance2, _, candidates_str = line.split("\t")
else:
utterance1, utterance2, *_ = line.split("\t")
Expand All @@ -79,26 +104,29 @@ def collate_fn(self, samples: list[ConvAI2Dialog], return_all_candidates: bool =
return_all_candidates = self.have_candidates & return_all_candidates
str_contexts = [" ".join(sample.context) for sample in samples]
# [batch size, context seq len]
b_contexts = self.context_tokenizer(
str_contexts,
max_length=self.max_context_len,
padding=True,
truncation=True,
return_tensors="pt",
add_special_tokens=False,
).input_ids
b_contexts = self.context_tokenizer(str_contexts, max_length=self.max_context_len, **self.tokenizer_kwargs)

str_conditions = []
if self.condition is Conditions.YOUR:
str_conditions = [" ".join(sample.my_persona) for sample in samples] # type: ignore
Copy link
Copy Markdown
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why are you ignoring mypy here? What error does it raise?

Copy link
Copy Markdown
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It raises the error: Argument 1 to "join" of "str" has incompatible type "Optional[List[str]]"; expected "Iterable[str]" [arg-type]. However, the condition guarantees that the my_persona attribute is set.

elif self.condition is Conditions.PARTNERS:
str_conditions = [" ".join(sample.partner_persona) for sample in samples] # type: ignore

if return_all_candidates:
str_candidates = [it for sample in samples for it in sample.candidates]
else:
str_candidates = [sample.candidates[0] for sample in samples]

# Tokenizer truncates on the left, but for candidates we want to truncate on the right
b_candidates = self.candidate_tokenizer(
str_candidates, padding="max_length", return_tensors="pt", add_special_tokens=False
).input_ids
b_candidates = b_candidates[:, : self.max_target_len]
# [batch size, # candidates, candidates seq len]
b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1))

return b_contexts, b_candidates.squeeze(1)
b_candidates = self.candidate_tokenizer(str_candidates, max_length=self.max_target_len, **self.tokenizer_kwargs)
# b_candidates = b_candidates[:, : self.max_target_len]
# # [batch size, # candidates, candidates seq len]
# b_candidates = b_candidates.view(len(samples), -1, b_candidates.size(1))

if self.condition is Conditions.NONE:
return b_contexts, b_candidates # .squeeze(1)
else:
b_conditions = self.candidate_tokenizer(
str_conditions, max_length=self.max_condition_len, **self.tokenizer_kwargs
)
return b_contexts, b_candidates, b_conditions # .squeeze(1)
Loading