-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy path__init__.py
155 lines (122 loc) · 4.18 KB
/
__init__.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
"""Concept Space Traversal plugin.
| Copyright 2017-2023, Voxel51, Inc.
| `voxel51.com <https://voxel51.com/>`_
|
"""
from scipy import linalg
import fiftyone as fo
import fiftyone.core.storage as fos
import fiftyone.operators as foo
import fiftyone.operators.types as types
def get_valid_indexes(dataset):
    """Return the brain keys of similarity indexes usable for traversal.

    A brain run qualifies when its config class name contains
    ``"Similarity"``, its config reports ``supports_prompts``, and its
    ``metric`` is ``"cosine"``.

    Args:
        dataset: the dataset whose brain runs are inspected

    Returns:
        a list of qualifying brain keys, in the order produced by
        ``dataset.list_brain_runs()``
    """

    def _qualifies(brain_key):
        cfg = dataset.get_brain_info(brain_key).config
        # The class-name test runs first so ``supports_prompts`` and
        # ``metric`` are only read for similarity configs.
        return (
            "Similarity" in cfg.cls
            and cfg.supports_prompts
            and cfg.metric == "cosine"
        )

    return [key for key in dataset.list_brain_runs() if _qualifies(key)]
def _normalize(embedding):
    """Scale ``embedding`` to unit L2 norm.

    Args:
        embedding: the vector (array-like) to normalize

    Returns:
        ``embedding / norm(embedding)``, or ``embedding`` unchanged when its
        norm is zero
    """
    norm = linalg.norm(embedding)
    if norm == 0:
        # A zero vector has no direction. Dividing by 0 would yield NaNs that
        # silently poison every downstream similarity computation.
        return embedding
    return embedding / norm
def generate_destination_vector(index, sample_id, concepts, text_scale):
    """Compute the query vector for a concept-space traversal.

    The sample's embedding is shifted toward the normalized, strength-weighted
    sum of the text embeddings of ``concepts``, and the result is normalized.

    Args:
        index: a similarity brain results object providing
            ``get_embeddings()`` and ``get_model()``
        sample_id: the ID of the starting sample
        concepts: a list of dicts, each with a ``"concept"`` prompt string
            and an optional ``"strength"`` weight (defaults to ``1.0``). If
            empty or ``None``, the sample's own normalized embedding is
            returned.
        text_scale: the weight applied to the normalized concept direction

    Returns:
        the normalized destination vector
    """
    # [0] selects the embeddings array (presumably get_embeddings returns
    # (embeddings, ids, ...)), and [0] its single row for this sample.
    sample_embedding = index.get_embeddings([sample_id])[0][0]

    concepts = concepts or []
    if not concepts:
        # No direction to move in. Summing nothing would give a zero vector,
        # and normalizing it produces NaNs.
        return _normalize(sample_embedding)

    # Embed all prompts in one model call rather than one call per concept.
    model = index.get_model()
    prompt_embeddings = model.embed_prompts([c["concept"] for c in concepts])
    concept_embedding = sum(
        embedding * concept.get("strength", 1.0)
        for embedding, concept in zip(prompt_embeddings, concepts)
    )

    if linalg.norm(concept_embedding) == 0:
        # For example, all strengths are zero, so there is no concept
        # direction to follow.
        return _normalize(sample_embedding)

    return _normalize(
        sample_embedding + text_scale * _normalize(concept_embedding)
    )
def run_traversal(ctx):
    """Build a view of the samples nearest to the traversal destination.

    Reads ``index`` (brain key), ``sample`` (sample ID), ``concepts`` and
    ``text_scale`` from ``ctx.params``. It computes the destination vector
    and sorts the dataset by similarity to it.

    Args:
        ctx: the operator execution context

    Returns:
        a view with up to 25 samples, sorted by similarity to the
        destination vector
    """
    params = ctx.params
    dataset = ctx.dataset
    index = dataset.load_brain_results(params.get("index"))

    destination = generate_destination_vector(
        index,
        params.get("sample"),
        params.get("concepts"),
        params.get("text_scale"),
    )
    return dataset.sort_by_similarity(destination, brain_key=index.key, k=25)
class OpenTraversalPanel(foo.Operator):
    """Operator that opens the ``TraversalPanel`` in the App.

    It is exposed as a button among the samples grid's secondary actions.
    """

    @property
    def config(self):
        return foo.OperatorConfig(
            name="open_traversal_panel",
            label="Concept Traversal: open traversal panel",
            icon="/assets/mesh_dark.svg",
        )

    def resolve_placement(self, ctx):
        """Place an "Open Traversal Panel" button in the grid's actions."""
        button = types.Button(
            label="Open Traversal Panel",
            icon="/assets/mesh_dark.svg",
            icon_dark="/assets/mesh_dark.svg",
            icon_light="/assets/mesh_light.svg",
            prompt=False,
        )
        return types.Placement(
            types.Places.SAMPLES_GRID_SECONDARY_ACTIONS, button
        )

    def execute(self, ctx):
        """Trigger ``open_panel`` for ``TraversalPanel`` (horizontal layout)."""
        panel_params = {
            "name": "TraversalPanel",
            "isActive": True,
            "layout": "horizontal",
        }
        ctx.trigger("open_panel", params=panel_params)
class RunTraversal(foo.Operator):
    """Unlisted operator that runs a traversal and loads the resulting view."""

    @property
    def config(self):
        return foo.OperatorConfig(
            name="traverser",
            label="Traverse",
            unlisted=True,
        )

    def resolve_input(self, ctx):
        """Declare the parameters that ``run_traversal`` reads."""
        inputs = types.Object()
        inputs.str("sample", label="Sample ID", required=True)
        inputs.str("index", label="Brain Key", required=True)
        inputs.float("text_scale", label="Text Scale", required=True)

        # One row per concept, holding its text prompt.
        # NOTE(review): generate_destination_vector also reads a "strength"
        # key from each row, which is not declared here. Confirm that the
        # panel supplies it.
        concept_row = types.Object()
        concept_row.str("concept", label="Concept", view=types.View(space=8))
        inputs.list("concepts", concept_row, label="Concepts")

        return types.Property(inputs)

    def execute(self, ctx):
        """Compute the traversal view and set it in the App."""
        ctx.ops.set_view(view=run_traversal(ctx))
class GetSampleURL(foo.Operator):
    """Unlisted operator that resolves a URL for a sample's media file."""

    @property
    def config(self):
        return foo.OperatorConfig(
            name="get_sample_url",
            label="Concept Traversal: Get sample URL",
            unlisted=True,
        )

    def execute(self, ctx):
        """Return ``{"url": ...}`` for the sample whose ID is ``ctx.params["id"]``.

        This is best effort: if the ID is missing, or the sample or its URL
        cannot be resolved, an empty dict is returned instead of raising.
        """
        sample_id = ctx.params.get("id", None)
        if sample_id is None:
            return {}

        try:
            filepath = ctx.dataset[sample_id].filepath
            return {"url": self._resolve_url(filepath)}
        except Exception:
            # Narrowed from a bare ``except:`` so KeyboardInterrupt and
            # SystemExit still propagate.
            return {}

    @staticmethod
    def _resolve_url(filepath):
        """Return a URL for ``filepath``.

        Uses ``fos.get_url()`` when it succeeds. Otherwise falls back to the
        local App server's ``/media`` route.
        """
        try:
            # get_url may not exist in every FiftyOne build (hence the
            # pylint suppression).
            # pylint: disable=no-member
            return fos.get_url(filepath)
        except Exception:
            return GetSampleURL._local_media_url(filepath)

    @staticmethod
    def _local_media_url(filepath):
        """Build the App ``/media`` URL that serves ``filepath``."""
        from urllib.parse import quote

        # default_app_address may be unset (None); NOTE(review): localhost
        # is assumed as the fallback host.
        address = fo.config.default_app_address or "localhost"
        port = fo.config.default_app_port
        # Percent-encode the path so characters such as '&', '#', '?' or
        # spaces cannot break the query string.
        encoded = quote(filepath, safe="")
        return f"http://{address}:{port}/media?filepath={encoded}"
def register(p):
    """Register this plugin's operators with the plugin context ``p``."""
    for operator_cls in (RunTraversal, OpenTraversalPanel, GetSampleURL):
        p.register(operator_cls)