-
Notifications
You must be signed in to change notification settings - Fork 179
/
Copy pathGPT2.swift
95 lines (81 loc) · 3.03 KB
/
GPT2.swift
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
//
// GPT2.swift
// CoreMLGPT2
//
// Created by Julien Chaumond on 19/07/2019.
// Copyright © 2019 Hugging Face. All rights reserved.
//
import Foundation
import CoreML
class GPT2 {

    /// How the next token is chosen from the model's output logits.
    enum DecodingStrategy {
        /// At each time step, we select the most likely next token
        case greedy
        /// Sample only from the top-k most-probable tokens (k is a hyper-parameter).
        case topK(Int)
        /// Sample from the top tokens with a cumulative probability just above a threshold (nucleus/top-p).
        case topP(Double)
    }

    private let model = distilgpt2_64_6()
    public let tokenizer = GPT2Tokenizer()
    /// Fixed sequence length the Core ML model was exported with.
    public let seqLen = 64
    private let strategy: DecodingStrategy

    init(strategy: DecodingStrategy = .greedy) {
        self.strategy = strategy
    }

    /// Main prediction loop:
    /// Predict next token from array of previous tokens.
    /// - featurization (truncate to the last `seqLen` tokens, right-pad with 0)
    /// - model inference
    /// - Decoding according to the model's `strategy`
    ///
    /// - Parameter tokens: token ids of the context so far (any length).
    /// - Returns: the id of the predicted next token.
    func predict(tokens: [Int]) -> Int {
        /// Keep the LAST `seqLen` tokens so that newly generated tokens remain
        /// inside the model's context window. (Taking the first `seqLen` would
        /// freeze the context once it fills up: every subsequent call would see
        /// the exact same input and generation would loop on one prediction.)
        let windowTokens = (tokens.count > seqLen)
            ? Array(tokens.suffix(seqLen))
            : tokens
        /// Pad input_ids on the right, up to `seqLen`:
        let input_ids = MLMultiArray.from(
            windowTokens + Array(repeating: 0, count: seqLen - windowTokens.count)
        )
        let position_ids = MLMultiArray.from(
            Array(0..<seqLen)
        )
        // Force-try is intentional: the model is bundled with the app, so a
        // prediction failure here is a programmer error, not a recoverable one.
        let output = try! model.prediction(input_ids: input_ids, position_ids: position_ids)

        /// Logits for the position of the last real (non-padding) token.
        let outputLogits = MLMultiArray.slice(
            output.output_logits,
            indexing: [.select(0), .select(windowTokens.count - 1), .slice, .select(0), .select(0)]
        )

        switch strategy {
        case .greedy:
            let nextToken = Math.argmax(outputLogits)
            return nextToken.0
        case .topK(let k):
            let logits = MLMultiArray.toDoubleArray(outputLogits)
            let topk = Math.topK(arr: logits, k: k)
            let sampleIndex = Math.sample(indexes: topk.indexes, probs: topk.probs)
            return sampleIndex
        case .topP(_):
            fatalError("topP is not implemented yet")
        }
    }

    /// Main generation loop.
    ///
    /// Will generate next `nTokens` (defaults to 10).
    /// Calls an incremental `callback` for each new token (with the decoded
    /// string so far and the time the prediction took), then returns the
    /// generated string at the end.
    ///
    /// - Parameters:
    ///   - text: prompt text to condition the generation on.
    ///   - nTokens: number of tokens to generate.
    ///   - callback: invoked once per generated token with (decoded new text, seconds).
    /// - Returns: the decoded string of ONLY the newly generated tokens.
    func generate(text: String, nTokens: Int = 10, callback: ((String, Double) -> Void)?) -> String {
        var tokens = tokenizer.encode(text: text)
        var newTokens: [Int] = []
        for i in 0..<nTokens {
            let (nextToken, time) = Utils.time {
                return predict(tokens: tokens)
            }

            tokens.append(nextToken)
            newTokens.append(nextToken)
            print("🦄 <\(time)s>", i, nextToken, tokens.count)
            callback?(
                tokenizer.decode(tokens: newTokens), time
            )
        }
        return tokenizer.decode(tokens: newTokens)
    }
}