base.py
from __future__ import annotations

from abc import ABC, abstractmethod
from dataclasses import asdict, dataclass
from typing import Any, Dict, List, Literal, Optional

import numpy as np
from pydrantic import BaseConfig, ObjectConfig

from cartridges.clients.usage import Usage


class CartridgeConfig(BaseConfig):
    id: str
    source: Literal["huggingface", "wandb"]
    force_redownload: bool = False


class ClientConfig(ObjectConfig):
    _pass_as_config: bool = True
    model_name: str
    show_progress_bar: bool = False

    def instantiate(self, *args, **kwargs) -> "Client":
        return super().instantiate(*args, **kwargs)


class Client(ABC):
    def __init__(self, config: ClientConfig):
        self.config = config

    @abstractmethod
    async def chat(
        self,
        chats: List[List[Dict[str, Any]]],
        temperature: float = 0.6,
        stop: List[str] = [],
        max_completion_tokens: Optional[int] = None,
        frequency_penalty: float = 0.0,
        top_logprobs: int = 1,
        logprobs_start_message: Optional[int] = None,
        modal_upstream_id: Optional[str] = None,
    ) -> ClientResponse:
        raise NotImplementedError
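

# ──────────────────────────────────────────────────────────
# Illustrative sketch (not part of the library): a minimal
# concrete client paired with its config. `EchoClientConfig`,
# `EchoClient`, and the Usage keyword arguments below are
# assumptions for demonstration only; a real client would
# wrap an inference backend.
# ──────────────────────────────────────────────────────────
class EchoClientConfig(ClientConfig):
    model_name: str = "echo"


class EchoClient(Client):
    """Toy client that returns each chat's last message as its sample."""

    async def chat(
        self,
        chats: List[List[Dict[str, Any]]],
        temperature: float = 0.6,
        stop: List[str] = [],
        max_completion_tokens: Optional[int] = None,
        frequency_penalty: float = 0.0,
        top_logprobs: int = 1,
        logprobs_start_message: Optional[int] = None,
        modal_upstream_id: Optional[str] = None,
    ) -> ClientResponse:
        # One sample per conversation: echo the final message's content.
        samples = [ClientSample(text=chat[-1]["content"]) for chat in chats]
        # NOTE: the Usage kwargs here are assumed, not taken from
        # cartridges.clients.usage.
        usage = Usage(prompt_tokens=0, completion_tokens=0)
        return ClientResponse(samples=samples, usage=usage)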


@dataclass(slots=True)
class ClientResponse:
    samples: List[ClientSample]
    usage: Usage
    timings: Optional[List[Dict[str, Any]]] = None

    def to_dict(self):
        return asdict(self)


@dataclass(slots=True)
class ClientSample:
    text: str
    token_ids: Optional[List[int]] = None
    top_logprobs: Optional[TopLogprobs] = None


@dataclass(slots=True)
class FlatTopLogprobs:
    """
    A flat / sparse view produced by `TopLogprobs.flatten`.

    token_idx – 1-D [N] row number in the original matrix
    token_id  – 1-D [N] vocabulary id
    logprobs  – 1-D [N] natural-log probabilities
    shape     – (num_tokens, num_top_logprobs)
    """

    token_idx: np.ndarray
    token_id: np.ndarray
    logprobs: np.ndarray
    shape: tuple[int, int]

    # ──────────────────────────────────────────────────────────
    # Inflate the sparse view back to dense [T, K] tensors.
    # Missing entries: token_id = -1, logprob = -1000.0 (a large
    # finite stand-in for -inf).
    # ──────────────────────────────────────────────────────────
    def reconstruct(self) -> "TopLogprobs":
        T, K = self.shape
        dense_logp = np.full((T, K), -1000.0, dtype=self.logprobs.dtype)
        dense_ids = np.full((T, K), -1, dtype=self.token_id.dtype)
        # The kept entries for each row always occupy the first *n*
        # columns, where n = (token_idx == row).sum(). We fill the
        # matrix row by row (vectorised per row; the tiny loop over
        # T has negligible cost).
        for row in range(T):
            row_mask = self.token_idx == row  # boolean mask
            n_keep = row_mask.sum()
            if n_keep:  # skip empty rows
                dense_logp[row, :n_keep] = self.logprobs[row_mask]
                dense_ids[row, :n_keep] = self.token_id[row_mask]
        return TopLogprobs(logprobs=dense_logp, token_ids=dense_ids)


@dataclass(slots=True)
class TopLogprobs:
    """
    logprobs  – [num_tokens, num_top_logprobs] (sorted descending, ln p)
    token_ids – [num_tokens, num_top_logprobs]
    """

    logprobs: np.ndarray
    token_ids: np.ndarray

    # ──────────────────────────────────────────────────────────
    # Flatten while retaining only the leading columns whose
    # cumulative probability per row reaches `threshold`.
    # ──────────────────────────────────────────────────────────
    def flatten(self, threshold: float = 0.99) -> FlatTopLogprobs:
        if self.logprobs.ndim != 2 or self.token_ids.ndim != 2:
            raise ValueError("logprobs and token_ids must be 2-D arrays")
        if self.logprobs.shape != self.token_ids.shape:
            raise ValueError("Shapes of logprobs and token_ids differ")
        if not (0.0 < threshold <= 1.0):
            raise ValueError("threshold must be in (0, 1]")
        T, K = self.logprobs.shape
        # 1. cumulative probability mass (rows already sorted)
        probs = np.exp(self.logprobs)        # [T, K]
        cum_mass = np.cumsum(probs, axis=1)  # [T, K]
        # 2. per-row cut-off index (inclusive). If a row's top-K mass
        #    never reaches the threshold, keep all K columns rather
        #    than argmax's default of 0.
        reaches = cum_mass >= threshold      # [T, K]
        cut_idx = np.where(reaches.any(axis=1), reaches.argmax(axis=1), K - 1)  # [T]
        # 3. boolean mask: keep columns 0 … cut_idx[row]
        mask = np.arange(K) < (cut_idx[:, None] + 1)  # [T, K]
        # 4. flatten in row-major order
        token_idx = np.repeat(np.arange(T), K)[mask.ravel()]  # [N]
        token_id = self.token_ids[mask]                       # [N]
        logprobs = self.logprobs[mask]                        # [N]
        return FlatTopLogprobs(
            token_idx=token_idx,
            token_id=token_id,
            logprobs=logprobs,
            shape=(T, K),
        )
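

# ──────────────────────────────────────────────────────────
# Usage sketch (illustrative, not part of the library):
# round-trip a tiny [2, 3] top-logprob matrix through
# flatten() / reconstruct(). The probabilities and token
# ids below are made up.
# ──────────────────────────────────────────────────────────
if __name__ == "__main__":
    top = TopLogprobs(
        logprobs=np.log(np.array([[0.7, 0.2, 0.1], [0.98, 0.015, 0.005]])),
        token_ids=np.array([[11, 42, 7], [3, 99, 12]]),
    )
    flat = top.flatten(threshold=0.8)  # row 0 keeps 2 columns, row 1 keeps 1
    dense = flat.reconstruct()
    # Kept entries survive the round trip; dropped slots are padded with
    # token_id = -1 and logprob = -1000.0.
    assert (dense.token_ids[0, :2] == [11, 42]).all()
    assert dense.token_ids[1, 0] == 3 and dense.token_ids[1, 1] == -1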