diff --git a/.gitignore b/.gitignore
new file mode 100644
index 0000000..88b8d07
--- /dev/null
+++ b/.gitignore
@@ -0,0 +1,2 @@
+/.classpath
+/.project
diff --git a/Llama3.java b/Llama3.java
old mode 100755
new mode 100644
index 25a495f..bd3d4ab
--- a/Llama3.java
+++ b/Llama3.java
@@ -85,44 +85,45 @@ static void runInteractive(Llama model, Sampler sampler, Options options) {
             conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.SYSTEM, options.systemPrompt())));
         }
         int startPosition = 0;
-        Scanner in = new Scanner(System.in);
-        while (true) {
-            System.out.print("> ");
-            System.out.flush();
-            String userText = in.nextLine();
-            if (List.of("quit", "exit").contains(userText)) {
-                break;
-            }
-            if (state == null) {
-                state = model.createNewState();
-            }
-            conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
-            conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
-            Set<Integer> stopTokens = chatFormat.getStopTokens();
-            List<Integer> responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
-                if (options.stream()) {
-                    if (!model.tokenizer().isSpecialToken(token)) {
-                        System.out.print(model.tokenizer().decode(List.of(token)));
-                    }
-                }
-            });
-            // Include stop token in the prompt history, but not in the response displayed to the user.
-            conversationTokens.addAll(responseTokens);
-            startPosition = conversationTokens.size();
-            Integer stopToken = null;
-            if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
-                stopToken = responseTokens.getLast();
-                responseTokens.removeLast();
-            }
-            if (!options.stream()) {
-                String responseText = model.tokenizer().decode(responseTokens);
-                System.out.println(responseText);
-            }
-            if (stopToken == null) {
-                System.err.println("Ran out of context length...");
-                break;
-            }
-        }
+        try (Scanner in = new Scanner(System.in)) {
+            while (true) {
+                System.out.print("> ");
+                System.out.flush();
+                String userText = in.nextLine();
+                if (List.of("quit", "exit").contains(userText)) {
+                    break;
+                }
+                if (state == null) {
+                    state = model.createNewState();
+                }
+                conversationTokens.addAll(chatFormat.encodeMessage(new ChatFormat.Message(ChatFormat.Role.USER, userText)));
+                conversationTokens.addAll(chatFormat.encodeHeader(new ChatFormat.Message(ChatFormat.Role.ASSISTANT, "")));
+                Set<Integer> stopTokens = chatFormat.getStopTokens();
+                List<Integer> responseTokens = Llama.generateTokens(model, state, startPosition, conversationTokens.subList(startPosition, conversationTokens.size()), stopTokens, options.maxTokens(), sampler, options.echo(), token -> {
+                    if (options.stream()) {
+                        if (!model.tokenizer().isSpecialToken(token)) {
+                            System.out.print(model.tokenizer().decode(List.of(token)));
+                        }
+                    }
+                });
+                // Include stop token in the prompt history, but not in the response displayed to the user.
+                conversationTokens.addAll(responseTokens);
+                startPosition = conversationTokens.size();
+                Integer stopToken = null;
+                if (!responseTokens.isEmpty() && stopTokens.contains(responseTokens.getLast())) {
+                    stopToken = responseTokens.getLast();
+                    responseTokens.removeLast();
+                }
+                if (!options.stream()) {
+                    String responseText = model.tokenizer().decode(responseTokens);
+                    System.out.println(responseText);
+                }
+                if (stopToken == null) {
+                    System.err.println("Ran out of context length...");
+                    break;
+                }
+            }
+        }
     }
 
     static void runInstructOnce(Llama model, Sampler sampler, Options options) {
@@ -346,7 +347,8 @@ public int byteSize() {
         }
     }
 
-    private void loadModelImpl(FileChannel fileChannel) throws IOException {
+    @SuppressWarnings("preview")
+    private void loadModelImpl(FileChannel fileChannel) throws IOException {
         // The header of the file.
         readHeader(fileChannel); // gguf_header_t header;
         // Tensor infos, which can be used to locate the tensor data.
@@ -726,7 +728,7 @@ private static Tokenizer createTokenizer(Map<String, Object> metadata, Vocabular
 
         int allTokens = vocabulary.size();
         int baseTokens = 128000; // assume all tokens after the base ones are special.
-        int reservedSpecialTokens = allTokens - baseTokens;
+
         List<String> specialTokensList = Arrays.stream(vocabulary.tokens(), baseTokens, allTokens).toList();
 
         assert specialTokensList.stream().allMatch(token -> vocabulary.getIndex(token).isPresent());
@@ -1868,9 +1870,11 @@ public static Pair<float[], float[]> precomputeFreqsCis(int contextLength, int h
             float loFreqWavelen = oldContextLength / loFreqFactor;
             float hiFreqWavelen = oldContextLength / hiFreqFactor;
             float wavelen = (float) (2.0 * Math.PI / freq);
+
+            // The first branch is a no-op (freq = freq), so it triggers a compiler warning.
             if (wavelen < hiFreqWavelen) {
                 freq = freq;
-            } else if (wavelen > loFreqWavelen) {
+            } else if (wavelen > loFreqWavelen) {
                 freq = freq / scaleFactor;
             } else {
                 float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
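
Note on the try-with-resources change in runInteractive: closing a Scanner also closes its underlying stream, so System.in is closed once the REPL loop exits. That is harmless here because the loop is the program's last consumer of stdin, but it matters if runInteractive were ever invoked twice in one process. A minimal, self-contained sketch of the same pattern (the ReplSketch class and its echo behavior are hypothetical, not part of this patch):

    import java.util.Scanner;

    public class ReplSketch {
        public static void main(String[] args) {
            // try-with-resources closes the Scanner, and with it System.in, on exit.
            try (Scanner in = new Scanner(System.in)) {
                while (true) {
                    System.out.print("> ");
                    System.out.flush();
                    if (!in.hasNextLine()) {
                        break; // EOF (e.g. Ctrl-D); nextLine() would throw here.
                    }
                    String line = in.nextLine();
                    if (line.equals("quit") || line.equals("exit")) {
                        break;
                    }
                    System.out.println("echo: " + line);
                }
            }
            // System.in is closed here; a new Scanner on it would just report end of input.
        }
    }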
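On the no-op branch flagged in precomputeFreqsCis: the self-assignment freq = freq; only documents that high-frequency components are left unscaled. A warning-free restructuring is sketched below. It assumes the standard Llama 3 RoPE setup where loFreqFactor < hiFreqFactor, hence loFreqWavelen >= hiFreqWavelen; the body of the smoothing branch follows the usual Llama 3 interpolation formula and is an assumption, since the patch context cuts off after the smooth line:

    // Hypothetical helper, equivalent to the wavelength check in precomputeFreqsCis
    // whenever loFreqWavelen >= hiFreqWavelen. High frequencies (wavelen < hiFreqWavelen)
    // simply fall through unchanged, so no no-op branch is needed.
    static float scaleFreq(float freq, int oldContextLength, float scaleFactor,
                           float loFreqFactor, float hiFreqFactor) {
        float loFreqWavelen = oldContextLength / loFreqFactor;
        float hiFreqWavelen = oldContextLength / hiFreqFactor;
        float wavelen = (float) (2.0 * Math.PI / freq);
        if (wavelen > loFreqWavelen) {
            freq = freq / scaleFactor; // low-frequency band: scale down
        } else if (wavelen >= hiFreqWavelen) {
            // mid band: blend scaled and unscaled (assumed Llama 3 smoothing formula)
            float smooth = (oldContextLength / wavelen - loFreqFactor) / (hiFreqFactor - loFreqFactor);
            freq = (1.0f - smooth) * freq / scaleFactor + smooth * freq;
        }
        return freq; // wavelen < hiFreqWavelen: left unchanged
    }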