-
Notifications
You must be signed in to change notification settings - Fork 635
/
Copy pathcompute_embeddings.yaml
65 lines (53 loc) · 1.79 KB
/
compute_embeddings.yaml
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
# SkyPilot task: compute embeddings for one index range of documents with a
# local vLLM server and write them to a mounted bucket.
name: compute-law-embeddings

# Sync the current directory to the cluster; the run section invokes
# scripts/compute_embeddings.py relative to it.
workdir: .

resources:
  accelerators:
    L4: 1
  # Trailing "+" requests at least this much memory (see SkyPilot resource spec).
  memory: 32+
  # Either spot or on-demand capacity is acceptable.
  any_of:
    - use_spot: true
    - use_spot: false

envs:
  START_IDX: 0  # Will be overridden by batch_compute_vectors.py
  END_IDX: 10000  # Will be overridden by batch_compute_vectors.py
  MODEL_NAME: "Alibaba-NLP/gte-Qwen2-7B-instruct"
  EMBEDDINGS_BUCKET_NAME: sky-rag-embeddings  # Bucket name for storing embeddings

file_mounts:
  # The run section writes its parquet output under /output, which is backed
  # by the bucket named in EMBEDDINGS_BUCKET_NAME.
  /output:
    name: ${EMBEDDINGS_BUCKET_NAME}
    mode: MOUNT
# One-time environment preparation: installs the Python packages used by the
# vLLM server and by the embedding script.
setup: |
  # Stop at the first failed install so a broken environment is not
  # hidden by a later command that happens to succeed.
  set -e

  # Install dependencies for vLLM
  pip install transformers==4.48.1 vllm==0.6.6.post1

  # Install dependencies for embedding computation
  pip install numpy pandas requests tqdm datasets
  pip install nltk hf_transfer
# Job body: download the model, serve it with vLLM on localhost:8000, embed
# documents START_IDX..END_IDX, then stop the server. The job's exit status
# is the exit status of the embedding script.
run: |
  # Initialize and download the model
  HF_HUB_ENABLE_HF_TRANSFER=1 huggingface-cli download --local-dir /tmp/model "$MODEL_NAME" || exit 1

  # Start vLLM service in background
  python -m vllm.entrypoints.openai.api_server \
    --host 0.0.0.0 \
    --model /tmp/model \
    --max-model-len 3072 \
    --task embed &
  # Remember the server's PID so a startup crash can be detected below.
  VLLM_PID=$!

  # Wait for vLLM to be ready by checking the health endpoint
  echo "Waiting for vLLM service to be ready..."
  # -f makes curl fail on HTTP error responses, not only on connection errors.
  while ! curl -sf http://localhost:8000/health > /dev/null; do
    # If the server process is gone it will never become healthy; abort
    # instead of polling forever.
    if ! kill -0 "$VLLM_PID" 2>/dev/null; then
      echo "vLLM service exited before becoming ready" >&2
      exit 1
    fi
    sleep 5
    echo "Still waiting for vLLM service..."
  done
  echo "vLLM service is ready!"

  # Process the assigned range of documents
  echo "Processing documents from $START_IDX to $END_IDX"
  python scripts/compute_embeddings.py \
    --output-path "/output/embeddings_${START_IDX}_${END_IDX}.parquet" \
    --start-idx "$START_IDX" \
    --end-idx "$END_IDX" \
    --chunk-size 2048 \
    --chunk-overlap 512 \
    --vllm-endpoint http://localhost:8000 \
    --batch-size 32
  # Keep the result so cleanup below cannot mask a failed computation.
  EMBED_STATUS=$?

  # Clean up vLLM service
  pkill -f "python -m vllm.entrypoints.openai.api_server"
  echo "vLLM service has been stopped"

  # Report the embedding script's result as the job result.
  exit "$EMBED_STATUS"