diff --git a/site/en/gemma/docs/pytorch_gemma.ipynb b/site/en/gemma/docs/pytorch_gemma.ipynb index 3286096c7..b5b432de9 100644 --- a/site/en/gemma/docs/pytorch_gemma.ipynb +++ b/site/en/gemma/docs/pytorch_gemma.ipynb @@ -220,9 +220,11 @@ "outputs": [], "source": [ "import torch\n", + "from gemma import config as gemma_config\n", "\n", "# Set up model config.\n", "model_config = get_config_for_2b() if \"2b\" in VARIANT else get_config_for_7b()\n", + "model_config.architecture = gemma_config.Architecture.GEMMA_1\n", "model_config.tokenizer = tokenizer_path\n", "model_config.quant = 'quant' in VARIANT\n", "\n",