xAI Grok
Kodlama
Ücretsiz Checkpoint - xAI Grok
xAI Grok için "Checkpoint" örneği. Bu promptu kodlama görevleriniz için kullanın. AI'dan kod yazmasını, hata ayıklamasını veya optimizasyon önerileri sunmasını isteyin.
162 indirme
353 görüntüleme
37 beğeni
17 Şub
Community
Prompt
Kopyalayıp yapıştırın
# Copyright 2024 X.AI Corp.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
import contextlib
import logging
import math
import os
import pickle
import re
import shutil
import sys
import tempfile
from concurrent.futures import ThreadPoolExecutor, wait
from typing import Any, Optional
import jax
import numpy as np
from jax.experimental import multihost_utils
from model import QuantizedWeight8bit
# Module-level logger named after this module.
logger = logging.getLogger(__name__)
# Separate logger registered under the fixed name "rank".
rank_logger = logging.getLogger("rank")
# Needed for loading the checkpoint with pickle.
# pickle resolves classes by module path and name, so exposing the class as an
# attribute of the `__main__` module lets pickles that reference
# `__main__.QuantizedWeight8bit` be loaded.
# NOTE(review): presumably the checkpoints were written from a script where the
# class lived in `__main__` -- confirm.
sys.modules['__main__'].QuantizedWeight8bit = QuantizedWeight8bit
@contextlib.contextmanager
def copy_to_shm(file: str):
    """Yield a path under ``/dev/shm/`` holding the contents of ``file``.

    If ``file`` already starts with ``/dev/shm/`` it is yielded unchanged and
    nothing is copied or deleted. Otherwise ``file`` is copied into a freshly
    created temporary file in ``/dev/shm/``, that temporary path is yielded,
    and the temporary file is removed when the context exits (including when
    the copy or the body of the ``with`` block raises).

    Args:
        file: Path of the file to expose under ``/dev/shm/``.

    Yields:
        The path the caller should read from.
    """
    if file.startswith("/dev/shm/"):
        # Nothing to do, the file is already in shared memory.
        yield file
        return

    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    # mkstemp returns an open descriptor that is never used for I/O here
    # (copyfile opens the path itself). Close it immediately so it is not held
    # for the lifetime of the context and cannot leak if cleanup fails.
    os.close(fd)
    try:
        shutil.copyfile(file, tmp_path)
        yield tmp_path
    finally:
        os.remove(tmp_path)
@contextlib.contextmanager
def copy_from_shm(file: str):
    """Yield a scratch path under ``/dev/shm/`` and copy it to ``file`` on exit.

    A temporary file is created in ``/dev/shm/`` and its path is yielded so the
    caller can write to it. When the ``with`` body finishes without raising,
    the temporary file is copied to ``file``. The temporary file is removed in
    every case; if the body raises, ``file`` is not written.

    Args:
        file: Destination path that receives the written data.

    Yields:
        The temporary path the caller should write to.
    """
    tmp_dir = "/dev/shm/"
    fd, tmp_path = tempfile.mkstemp(dir=tmp_dir)
    # The descriptor from mkstemp is never used for I/O (the caller and
    # copyfile open the path themselves). Close it immediately so it is not
    # held open for the whole context and cannot leak if cleanup fails.
    os.close(fd)
    try:
        yield tmp_path
        shutil.copyfile(tmp_path, file)
    finally:
        os.remove(tmp_path)
def fast_unpickle(path: str) -> Any:
    """Unpickle the object stored at ``path`` via a copy obtained from ``copy_to_shm``.

    Args:
        path: Path of the pickle file to load.

    Returns:
        The object produced by ``pickle.load``.
    """
    # NOTE(security): pickle.load can execute arbitrary code embedded in the
    # file; only use this on checkpoint files from a trusted source.
    with copy_to_shm(path) as shm_path, open(shm_path, "rb") as stream:
        return pickle.load(stream)
def fast_pickle(obj: Any, path: str) -> None:
    """Pickle ``obj`` to ``path`` through the scratch file from ``copy_from_shm``.

    The object is dumped into the temporary path yielded by ``copy_from_shm``;
    that context manager is responsible for copying the result to ``path``.

    Args:
        obj: Object to serialize with ``pickle.dump``.
        path: Destination path of the pickle file.
    """
    with copy_from_shm(path) as shm_path, open(shm_path, "wb") as stream:
        pickle.dump(obj, stream)
def load_tensors(shaped_arrays, directory, mesh_config, tensor_indices=None):
    """Loads a set of arrays.

    One task per entry of ``shaped_arrays`` is submitted to a thread pool.
    For entries owned by this process the task unpickles the file
    ``tensor{i:05d}_{idx:03d}`` inside ``directory``, where ``i`` is the tensor
    index and ``idx`` is derived from ``jax.process_index()`` and the product
    of ``mesh_config``. For other entries a zero array with the entry's
    ``shape`` and ``dtype`` is created instead.

    Args:
        shaped_arrays: Iterable of objects exposing ``shape`` and ``dtype``.
        directory: Directory containing the per-tensor pickle files.
        mesh_config: Iterable of ints; their product is the shard count used
            when computing ``idx``.
        tensor_indices: Optional iterable of tensor indices paired with
            ``shaped_arrays``. When ``None``, entries are numbered from 0.

    Returns:
        A list with one loaded (or zero-filled) array per input entry, in
        input order.
    """
    # With num_replicas fixed at 1 the ownership test below is always true, so
    # every tensor is read from disk; the zero-fill branch is kept so the
    # logic stays valid if the replica count is ever changed.
    num_replicas = 1
    data_model_shards = math.prod(mesh_config)

    if tensor_indices is None:
        iterator = enumerate(shaped_arrays)
    else:
        iterator = zip(tensor_indices, shaped_arrays)

    # The context manager shuts the pool down on exit (also on error), so the
    # worker threads do not outlive this call.
    with ThreadPoolExecutor(max_workers=32) as pool:
        fs = []
        for i, t in iterator:
            process_index = jax.process_index()
            if (i % num_replicas) == ((process_index // data_model_shards) % num_replicas):
                idx = (
                    process_index // (num_replicas * data_model_shards) * data_model_shards
                    + process_index % data_model_shards
                )
                fs.append(
                    pool.submit(fast_unpickle, os.path.join(directory, f"tensor{i:05d}_{idx:03d}"))
                )
            else:
                fs.append(pool.submit(np.zeros, t.shape, dtype=t.dtype))
        # Let every task finish before surfacing the first failure, if any.
        wait(fs)
        return [f.result() for f in fs]
def path_tuple_to_string(path: tuple) -> str:
    """Convert a JAX key path into a ``/``-separated string.

    ``DictKey`` entries contribute their ``key`` and ``GetAttrKey`` entries
    contribute their ``name``. ``FlattenedIndexKey`` and ``SequenceKey``
    entries are skipped, so positional indices do not appear in the result.

    Args:
        path: Tuple of ``jax.tree_util`` key entries, as produced by
            ``jax.tree_util.tree_flatten_with_path``.

    Returns:
        The joined path string; an empty string for an empty path.

    Raises:
        TypeError: If an entry is not one of the four supported key types.
    """
    pieces = []
    for elem in path:
        if isinstance(elem, jax.tree_util.DictKey):
            # str() keeps non-string dict keys (e.g. ints) from breaking join.
            pieces.append(str(elem.key))
        elif isinstance(elem, jax.tree_util.GetAttrKey):
            pieces.append(str(elem.name))
        elif not isinstance(
            elem, (jax.tree_util.FlattenedIndexKey, jax.tree_util.SequenceKey)
        ):
            # Explicit raise instead of assert so the check survives `python -O`.
            raise TypeError(f"Unsupported key path element: {elem!r}")
    return "/".join(pieces)
def get_load_path_str(
    init_path_str: str,
    load_rename_rules: Optional[list[tuple[str, str]]] = None,
    load_exclude_rules: Optional[list[str]] = None,
) -> Optional[str]:
    """Map a parameter path to the path it should be loaded from.

    Exclusion is checked first: if any regex in ``load_exclude_rules`` is
    found in ``init_path_str`` the result is ``None``. Otherwise the first
    rename rule whose search pattern is found in the path is applied with
    ``re.sub`` and no further rules are considered.

    Args:
        init_path_str: Path string to translate.
        load_rename_rules: Optional ``(search_pattern, replacement)`` regex
            pairs, tried in order.
        load_exclude_rules: Optional regex patterns marking excluded paths.

    Returns:
        ``None`` when the path is excluded, the renamed path when a rename
        rule matches, and ``init_path_str`` unchanged otherwise.
    """
    # Exclusion takes precedence over renaming.
    is_excluded = load_exclude_rules is not None and any(
        re.search(pattern, init_path_str) for pattern in load_exclude_rules
    )
    if is_excluded:
        return None

    # Renaming: only the first matching rule is applied.
    if load_rename_rules is not None:
        for pattern, replacement in load_rename_rules:
            if re.search(pattern, init_path_str):
                return re.sub(pattern, replacement, init_path_str)

    return init_path_str
def replace_with_load_state(
init_state: Any,
load_state: Any,
load_rename_rules: Optional[list[tuple[str, str]]] = None,
load_exclude_rules: Optional[list[str]] = None,
mesh_config: tuple = (1, 1),
) -> Any:
flatten_load, _ = jax.tree_util.tree_flatten_with_path(load_state)
flatten_init, structure_init = jax.tree_util.tree_flatten_with_path(init_state)
load_map = {path_tuple_to_string(path): tensor for path, tensor in flatten_load}
replaced = []
num_replicas = 1
data_model_shards = math.prod(mesh_config)
for i, (init_path, tensor) in enumerate(flatten_init):
init_path_str = path_tuple_to_string(init_path)
load_path_str = get_load_path_str(init_path_str, load_rename_rules,
Test Edilmiş
Gerçek senaryolarda test edilmiş ve doğrulanmış prompt
Optimize Edilmiş
En iyi sonuçlar için optimize edilmiş prompt yapısı
Hemen Kullanın
Kopyala-yapıştır ile hemen kullanmaya başlayın