xAI Grok
Kodlama
Ücretsiz

Runners - xAI Grok

xAI Grok için "Runners" örneği. Bu promptu kodlama görevleriniz için kullanın. AI'dan kod yazmasını, hata ayıklamasını veya optimizasyon önerileri sunmasını isteyin.

142 indirme
371 görüntüleme
51 beğeni
17 Şub
Community

Prompt

Kopyalayıp yapıştırın

# Copyright 2024 X.AI Corp. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. import bisect import functools import logging import math import re from dataclasses import dataclass from typing import Any, Callable, NamedTuple, Optional, Tuple import haiku as hk import jax import jax.experimental.pjit as pjit import jax.numpy as jnp import numpy as np import sentencepiece from jax.experimental import mesh_utils from jax.sharding import PartitionSpec as P from jax.typing import ArrayLike import checkpoint as xai_checkpoint from model import ( LanguageModelConfig, LanguageModelOutput, TrainingState, apply_rules, Memory, KVMemory, ) logger = logging.getLogger(__name__) rank_logger = logging.getLogger("rank") TOP_K = 8 class SampleSettings(NamedTuple): temperature: ArrayLike nucleus_p: ArrayLike mask: ArrayLike # Whether a given batch element is actively used. [B] active: ArrayLike class SampleOutput(NamedTuple): token_id: ArrayLike prob: ArrayLike top_k_token_ids: ArrayLike top_k_probs: ArrayLike def insert_slice(memory: Memory, slice, length, i): slice = Memory( layers=[ KVMemory(layer.k, layer.v, step=jnp.array([length])) for layer in slice.layers ], ) return jax.tree_map(lambda m, u: jax.lax.dynamic_update_index_in_dim(m, u[0], i, axis=0), memory, slice) def pad_to_size(x, size): if x.shape[0] > size: # Left truncate if the context is too long. 
x = x[-size:] return np.pad(x, [0, size - x.shape[0]], mode="constant", constant_values=0) def top_p_filter(logits: jax.Array, top_p: jax.Array) -> jax.Array: """Performs nucleus filtering on logits.""" assert logits.ndim == top_p.ndim, f"Expected {logits.ndim} equal {top_p.ndim}" sorted_logits = jax.lax.sort(logits, is_stable=False) sorted_probs = jax.nn.softmax(sorted_logits) threshold_idx = jnp.argmax(jnp.cumsum(sorted_probs, -1) >= 1 - top_p, axis=-1) threshold_largest_logits = jnp.take_along_axis( sorted_logits, threshold_idx[..., jnp.newaxis], axis=-1 ) assert threshold_largest_logits.shape == logits.shape[:-1] + (1,) mask = logits >= threshold_largest_logits # Set unused logits to -inf. logits = jnp.where(mask, logits, -1e10) return logits def sample_token( rngs: jax.random.PRNGKey, lm_outputs: LanguageModelOutput, settings: SampleSettings, ) -> SampleOutput: # Expand the settings shape to match the logit shape. settings = SampleSettings( temperature=jnp.expand_dims(settings.temperature, (1, 2)), # Input [B], output [B, 1, 1]. nucleus_p=jnp.expand_dims(settings.nucleus_p, (1, 2)), # Input [B], output [B, 1, 1]. mask=jnp.expand_dims(settings.mask, 1), # Input [B, V], output [B, 1, V]. active=settings.active, # [B]. ) logits = lm_outputs.logits / settings.temperature.astype(lm_outputs.logits.dtype) # Mask out all disallowed tokens by assigning them a near-zero probability. logits = jnp.where(settings.mask, logits, -1e10) # Mask out all tokens that don't fall into the p-th percentile. logits = top_p_filter(logits, settings.nucleus_p.astype(logits.dtype)) new_token = jax.vmap(jax.random.categorical)(rngs, logits) probabilities = jax.nn.softmax(logits) token_prob = jnp.take_along_axis(probabilities, jnp.expand_dims(new_token, 1), axis=2) token_prob = jnp.squeeze(token_prob, 1) # Gather the top-k tokens and probabilities. 
top_k_probs, top_k_token_ids = jax.lax.top_k(probabilities, TOP_K) top_k_probs = jnp.squeeze(top_k_probs, 1) top_k_token_ids = jnp.squeeze(top_k_token_ids, 1) return SampleOutput( new_token, token_prob, top_k_token_ids, top_k_probs, ) @dataclass class ModelRunner: model: LanguageModelConfig bs_per_device: float = 2.0 load_rename_rules: Optional[list[tuple[str, str]]] = None load_exclude_rules: Optional[list[str]] = None rng_seed: int = 42 # Initial rng seed. transform_forward: bool = False checkpoint_path: str = "" def make_forward_fn(self, mesh: Any): def forward(tokens): out = self.model.make(mesh=mesh)(tokens) return out, None if self.transform_forward: forward = hk.transform(forward) return forward def initialize( self, init_data, local_mesh_config:

Test Edilmiş

Gerçek senaryolarda test edilmiş ve doğrulanmış prompt

Optimize Edilmiş

En iyi sonuçlar için optimize edilmiş prompt yapısı

Hemen Kullanın

Kopyala-yapıştır ile hemen kullanmaya başlayın